[Mlir-commits] [mlir] cb82d37 - [mlir][sparse][vectorization] optimize reduction chains

Aart Bik llvmlistbot at llvm.org
Sat Nov 26 12:41:03 PST 2022


Author: Aart Bik
Date: 2022-11-26T12:40:51-08:00
New Revision: cb82d375a8060bd3af83b64d7d2c94f4a59d4b97

URL: https://github.com/llvm/llvm-project/commit/cb82d375a8060bd3af83b64d7d2c94f4a59d4b97
DIFF: https://github.com/llvm/llvm-project/commit/cb82d375a8060bd3af83b64d7d2c94f4a59d4b97.diff

LOG: [mlir][sparse][vectorization] optimize reduction chains

A few more dots on the i's of the sparse vectorizer.
Also makes reduction matching less brittle.

Reviewed By: qcolombet

Differential Revision: https://reviews.llvm.org/D138513

Added: 
    mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index aed394990428d..028a471f41c8d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -101,7 +101,9 @@ static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
 }
 
 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
-/// where 'lo' denotes the current index and 'hi = lo + vl - 1'.
+/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
+/// that the sparse compiler can only generate indirect loads in
+/// the last index, i.e. back().
 static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                            Value ptr, ArrayRef<Value> idxs, Value vmask) {
   VectorType vtp = vectorType(vl, ptr);
@@ -118,7 +120,9 @@ static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
 }
 
 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
-/// where 'lo' denotes the current index and 'hi = lo + vl - 1'.
+/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
+/// that the sparse compiler can only generate indirect stores in
+/// the last index, i.e. back().
 static void genVectorStore(PatternRewriter &rewriter, Location loc, Value ptr,
                            ArrayRef<Value> idxs, Value vmask, Value rhs) {
   if (idxs.back().getType().isa<VectorType>()) {
@@ -132,32 +136,60 @@ static void genVectorStore(PatternRewriter &rewriter, Location loc, Value ptr,
   rewriter.create<vector::MaskedStoreOp>(loc, ptr, idxs, vmask, rhs);
 }
 
-/// Maps operation to combining kind for reduction.
-static vector::CombiningKind getCombiningKind(Operation *def) {
-  if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def) ||
-      isa<arith::SubFOp>(def) || isa<arith::SubIOp>(def))
-    return vector::CombiningKind::ADD;
-  if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
-    return vector::CombiningKind::MUL;
-  if (isa<arith::AndIOp>(def))
-    return vector::CombiningKind::AND;
-  if (isa<arith::OrIOp>(def))
-    return vector::CombiningKind::OR;
-  if (isa<arith::XOrIOp>(def))
-    return vector::CombiningKind::XOR;
-  llvm_unreachable("unknown reduction kind");
+/// Detects whether `red` is a vectorizable reduction of `iter`; returns true
+/// on success, with the reduction's combining kind stored in `kind`.
+static bool isVectorizableReduction(Value red, Value iter,
+                                    vector::CombiningKind &kind) {
+  if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
+    kind = vector::CombiningKind::ADD;
+    return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
+  }
+  if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
+    kind = vector::CombiningKind::ADD;
+    return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
+  }
+  if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
+    kind = vector::CombiningKind::ADD;
+    return subf->getOperand(0) == iter; // iter - x only, not x - iter
+  }
+  if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
+    kind = vector::CombiningKind::ADD;
+    return subi->getOperand(0) == iter; // iter - x only, not x - iter
+  }
+  if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
+    kind = vector::CombiningKind::MUL;
+    return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
+  }
+  if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
+    kind = vector::CombiningKind::MUL;
+    return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
+  }
+  if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
+    kind = vector::CombiningKind::AND;
+    return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
+  }
+  if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
+    kind = vector::CombiningKind::OR;
+    return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
+  }
+  if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
+    kind = vector::CombiningKind::XOR;
+    return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
+  }
+  return false; // no supported combining operation defines `red`
 }
 
 /// Generates an initial value for a vector reduction, following the scheme
 /// given in Chapter 5 of "The Software Vectorization Handbook", where the
 /// initial scalar value is correctly embedded in the vector reduction value,
 /// and a straightforward horizontal reduction will complete the operation.
-/// The value 'r' denotes the initial value of the accumulator. Value 'rd'
-/// denotes the accumulation operation, which is solely used here to determine
-/// the kind of combining reduction (viz. addf -> sum-accumulation).
+/// Value 'r' denotes the initial value of the reduction outside the loop.
 static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
-                                VectorType vtp, Value r, Value rd) {
-  vector::CombiningKind kind = getCombiningKind(rd.getDefiningOp());
+                                Value red, Value iter, Value r,
+                                VectorType vtp) {
+  vector::CombiningKind kind;
+  if (!isVectorizableReduction(red, iter, kind))
+    llvm_unreachable("unknown reduction");
   switch (kind) {
   case vector::CombiningKind::ADD:
   case vector::CombiningKind::XOR:
@@ -180,13 +212,6 @@ static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
   llvm_unreachable("unknown reduction kind");
 }
 
-/// Generates final value for a vector reduction.
-static Value genVectorReducEnd(PatternRewriter &rewriter, Location loc,
-                               Value vexp, Value rd) {
-  vector::CombiningKind kind = getCombiningKind(rd.getDefiningOp());
-  return rewriter.create<vector::ReductionOp>(loc, kind, vexp);
-}
-
 /// This method is called twice to analyze and rewrite the given subscripts.
 /// The first call (!codegen) does the analysis. Then, on success, the second
 /// call (codegen) yields the proper vector form in the output parameter
@@ -379,10 +404,14 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
     if (!yield.getResults().empty()) {
       Value init = forOp.getInitArgs()[0];
       VectorType vtp = vectorType(vl, init.getType());
-      Value vinit =
-          genVectorReducInit(rewriter, loc, vtp, init, yield->getOperand(0));
+      Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
+                                       forOp.getRegionIterArg(0), init, vtp);
       forOpNew = rewriter.create<scf::ForOp>(
           loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
+      forOpNew->setAttr(
+          SparseTensorLoopEmitter::getLoopEmitterLoopAttrName(),
+          forOp->getAttr(
+              SparseTensorLoopEmitter::getLoopEmitterLoopAttrName()));
       rewriter.setInsertionPointToStart(forOpNew.getBody());
     } else {
       forOp.setStep(step);
@@ -395,20 +424,22 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
   // Sparse for-loops either are terminated by a non-empty yield operation
   // (reduction loop) or otherwise by a store operation (pararallel loop).
   if (!yield.getResults().empty()) {
+    // Analyze/vectorize reduction.
     if (yield->getNumOperands() != 1)
       return false;
-    Value redOp = yield->getOperand(0);
-    // Analyze/vectorize reduction.
-    // TODO: use linalg utils to verify the actual reduction?
+    Value red = yield->getOperand(0);
+    Value iter = forOp.getRegionIterArg(0);
+    vector::CombiningKind kind;
     Value vrhs;
-    if (vectorizeExpr(rewriter, forOp, vl, redOp, codegen, vmask, vrhs)) {
+    if (isVectorizableReduction(red, iter, kind) &&
+        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
       if (codegen) {
-        Value vpass =
-            genVectorInvariantValue(rewriter, vl, forOp.getRegionIterArg(0));
+        Value partial = forOpNew.getResult(0);
+        Value vpass = genVectorInvariantValue(rewriter, vl, iter);
         Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
         rewriter.create<scf::YieldOp>(loc, vred);
         rewriter.setInsertionPointAfter(forOpNew);
-        Value vres = genVectorReducEnd(rewriter, loc, forOpNew.getResult(0), redOp);
+        Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
         // Now do some relinking (last one is not completely type safe
         // but all bad ones are removed right away). This also folds away
         // nop broadcast operations.
@@ -469,6 +500,32 @@ struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
   const VL vl;
 };
 
+/// Reduction chain cleanup: rewrites u = expand(vsum(v)) into v itself.
+///   v = for { }
+///   s = vsum(v)               v = for { }
+///   u = expand(s)       ->    for (v) { }
+///   for (u) { }
+template <typename VectorOp>
+struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
+public:
+  using OpRewritePattern<VectorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(VectorOp op,
+                                PatternRewriter &rewriter) const override {
+    Value inp = op.getSource(); // s, the expanded scalar
+    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) { // s = vsum(v)
+      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
+        if (forOp->hasAttr(
+                SparseTensorLoopEmitter::getLoopEmitterLoopAttrName())) {
+          rewriter.replaceOp(op, redOp.getVector()); // assumes types match
+          return success();
+        }
+      }
+    }
+    return failure();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -482,4 +539,6 @@ void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                                bool enableSIMDIndex32) {
   patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                               enableVLAVectorization, enableSIMDIndex32);
+  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
+               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
 }

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
new file mode 100644
index 0000000000000..612927c471920
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
@@ -0,0 +1,122 @@
+// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=8" -cse | \
+// RUN:   FileCheck %s
+
+#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["dense","compressed"]}>
+
+#trait = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // a (in)
+    affine_map<(i,j) -> (i,j)>,  // b (in)
+    affine_map<(i,j) -> ()>      // x (out)
+  ],
+  iterator_types = ["reduction", "reduction"]
+}
+
+//
+// Verifies that the SIMD reductions in the two for-loops after the
+// while-loop are chained before horizontally reducing these back to scalar.
+//
+// CHECK-LABEL:   func.func @sparse_matrix_sum(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<f64>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>,
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>) -> tensor<f64> {
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant dense<0.000000e+00> : vector<8xf64>
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 64 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 1 : index} : tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 1 : index} : tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> to memref<?xf64>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_2]] {dimension = 1 : index} : tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> to memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_2]] {dimension = 1 : index} : tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> to memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> to memref<?xf64>
+// CHECK:           %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_0]] : memref<f64>
+// CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_14]][] : memref<f64>
+// CHECK:           %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f64) {
+// CHECK:             %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xindex>
+// CHECK:             %[[VAL_20:.*]] = arith.addi %[[VAL_17]], %[[VAL_7]] : index
+// CHECK:             %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK:             %[[VAL_22:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_17]]] : memref<?xindex>
+// CHECK:             %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK:             %[[VAL_24:.*]]:3 = scf.while (%[[VAL_25:.*]] = %[[VAL_19]], %[[VAL_26:.*]] = %[[VAL_22]], %[[VAL_27:.*]] = %[[VAL_18]]) : (index, index, f64) -> (index, index, f64) {
+// CHECK:               %[[VAL_28:.*]] = arith.cmpi ult, %[[VAL_25]], %[[VAL_21]] : index
+// CHECK:               %[[VAL_29:.*]] = arith.cmpi ult, %[[VAL_26]], %[[VAL_23]] : index
+// CHECK:               %[[VAL_30:.*]] = arith.andi %[[VAL_28]], %[[VAL_29]] : i1
+// CHECK:               scf.condition(%[[VAL_30]]) %[[VAL_25]], %[[VAL_26]], %[[VAL_27]] : index, index, f64
+// CHECK:             } do {
+// CHECK:             ^bb0(%[[VAL_31:.*]]: index, %[[VAL_32:.*]]: index, %[[VAL_33:.*]]: f64):
+// CHECK:               %[[VAL_34:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_31]]] : memref<?xindex>
+// CHECK:               %[[VAL_35:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref<?xindex>
+// CHECK:               %[[VAL_36:.*]] = arith.cmpi ult, %[[VAL_35]], %[[VAL_34]] : index
+// CHECK:               %[[VAL_37:.*]] = arith.select %[[VAL_36]], %[[VAL_35]], %[[VAL_34]] : index
+// CHECK:               %[[VAL_38:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_37]] : index
+// CHECK:               %[[VAL_39:.*]] = arith.cmpi eq, %[[VAL_35]], %[[VAL_37]] : index
+// CHECK:               %[[VAL_40:.*]] = arith.andi %[[VAL_38]], %[[VAL_39]] : i1
+// CHECK:               %[[VAL_41:.*]] = scf.if %[[VAL_40]] -> (f64) {
+// CHECK:                 %[[VAL_42:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_31]]] : memref<?xf64>
+// CHECK:                 %[[VAL_43:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK:                 %[[VAL_44:.*]] = arith.addf %[[VAL_42]], %[[VAL_43]] : f64
+// CHECK:                 %[[VAL_45:.*]] = arith.addf %[[VAL_33]], %[[VAL_44]] : f64
+// CHECK:                 scf.yield %[[VAL_45]] : f64
+// CHECK:               } else {
+// CHECK:                 %[[VAL_46:.*]] = scf.if %[[VAL_38]] -> (f64) {
+// CHECK:                   %[[VAL_47:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_31]]] : memref<?xf64>
+// CHECK:                   %[[VAL_48:.*]] = arith.addf %[[VAL_33]], %[[VAL_47]] : f64
+// CHECK:                   scf.yield %[[VAL_48]] : f64
+// CHECK:                 } else {
+// CHECK:                   %[[VAL_49:.*]] = scf.if %[[VAL_39]] -> (f64) {
+// CHECK:                     %[[VAL_50:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK:                     %[[VAL_51:.*]] = arith.addf %[[VAL_33]], %[[VAL_50]] : f64
+// CHECK:                     scf.yield %[[VAL_51]] : f64
+// CHECK:                   } else {
+// CHECK:                     scf.yield %[[VAL_33]] : f64
+// CHECK:                   }
+// CHECK:                   scf.yield %[[VAL_52:.*]] : f64
+// CHECK:                 }
+// CHECK:                 scf.yield %[[VAL_53:.*]] : f64
+// CHECK:               }
+// CHECK:               %[[VAL_54:.*]] = arith.addi %[[VAL_31]], %[[VAL_7]] : index
+// CHECK:               %[[VAL_55:.*]] = arith.select %[[VAL_38]], %[[VAL_54]], %[[VAL_31]] : index
+// CHECK:               %[[VAL_56:.*]] = arith.addi %[[VAL_32]], %[[VAL_7]] : index
+// CHECK:               %[[VAL_57:.*]] = arith.select %[[VAL_39]], %[[VAL_56]], %[[VAL_32]] : index
+// CHECK:               scf.yield %[[VAL_55]], %[[VAL_57]], %[[VAL_58:.*]] : index, index, f64
+// CHECK:             } attributes {"Emitted from" = "linalg.generic"}
+// CHECK:             %[[VAL_59:.*]] = vector.insertelement %[[VAL_60:.*]]#2, %[[VAL_4]]{{\[}}%[[VAL_6]] : index] : vector<8xf64>
+// CHECK:             %[[VAL_61:.*]] = scf.for %[[VAL_62:.*]] = %[[VAL_60]]#0 to %[[VAL_21]] step %[[VAL_3]] iter_args(%[[VAL_63:.*]] = %[[VAL_59]]) -> (vector<8xf64>) {
+// CHECK:               %[[VAL_64:.*]] = affine.min #map2(%[[VAL_21]], %[[VAL_62]]){{\[}}%[[VAL_3]]]
+// CHECK:               %[[VAL_65:.*]] = vector.create_mask %[[VAL_64]] : vector<8xi1>
+// CHECK:               %[[VAL_66:.*]] = vector.maskedload %[[VAL_10]]{{\[}}%[[VAL_62]]], %[[VAL_65]], %[[VAL_4]] : memref<?xf64>, vector<8xi1>, vector<8xf64> into vector<8xf64>
+// CHECK:               %[[VAL_67:.*]] = arith.addf %[[VAL_63]], %[[VAL_66]] : vector<8xf64>
+// CHECK:               %[[VAL_68:.*]] = arith.select %[[VAL_65]], %[[VAL_67]], %[[VAL_63]] : vector<8xi1>, vector<8xf64>
+// CHECK:               scf.yield %[[VAL_68]] : vector<8xf64>
+// CHECK:             } {"Emitted from" = "linalg.generic"}
+// CHECK:             %[[VAL_69:.*]] = scf.for %[[VAL_70:.*]] = %[[VAL_60]]#1 to %[[VAL_23]] step %[[VAL_3]] iter_args(%[[VAL_71:.*]] = %[[VAL_61]]) -> (vector<8xf64>) {
+// CHECK:               %[[VAL_73:.*]] = affine.min #map2(%[[VAL_23]], %[[VAL_70]]){{\[}}%[[VAL_3]]]
+// CHECK:               %[[VAL_74:.*]] = vector.create_mask %[[VAL_73]] : vector<8xi1>
+// CHECK:               %[[VAL_75:.*]] = vector.maskedload %[[VAL_13]]{{\[}}%[[VAL_70]]], %[[VAL_74]], %[[VAL_4]] : memref<?xf64>, vector<8xi1>, vector<8xf64> into vector<8xf64>
+// CHECK:               %[[VAL_76:.*]] = arith.addf %[[VAL_71]], %[[VAL_75]] : vector<8xf64>
+// CHECK:               %[[VAL_77:.*]] = arith.select %[[VAL_74]], %[[VAL_76]], %[[VAL_71]] : vector<8xi1>, vector<8xf64>
+// CHECK:               scf.yield %[[VAL_77]] : vector<8xf64>
+// CHECK:             } {"Emitted from" = "linalg.generic"}
+// CHECK:             %[[VAL_78:.*]] = vector.reduction <add>, %[[VAL_69]] : vector<8xf64> into f64
+// CHECK:             scf.yield %[[VAL_78]] : f64
+// CHECK:           } {"Emitted from" = "linalg.generic"}
+// CHECK:           memref.store %[[VAL_80:.*]], %[[VAL_14]][] : memref<f64>
+// CHECK:           %[[VAL_81:.*]] = bufferization.to_tensor %[[VAL_14]] : memref<f64>
+// CHECK:           return %[[VAL_81]] : tensor<f64>
+// CHECK:         }
+func.func @sparse_matrix_sum(%argx: tensor<f64>,
+                             %arga: tensor<64x32xf64, #SparseMatrix>,
+                             %argb: tensor<64x32xf64, #SparseMatrix>) -> tensor<f64> {
+  %0 = linalg.generic #trait  // x += sum over (i,j) of a(i,j) + b(i,j)
+     ins(%arga, %argb: tensor<64x32xf64, #SparseMatrix>,
+                       tensor<64x32xf64, #SparseMatrix>)
+      outs(%argx: tensor<f64>) {
+      ^bb(%a: f64, %b: f64, %x: f64):
+        %m = arith.addf %a, %b : f64
+        %t = arith.addf %x, %m : f64  // accumulate into x
+        linalg.yield %t : f64
+  } -> tensor<f64>
+  return %0 : tensor<f64>
+}


        


More information about the Mlir-commits mailing list