[Mlir-commits] [mlir] 36b66ab - [mlir][sparse] add support for "simply dynamic" sparse tensor expressions

Aart Bik llvmlistbot at llvm.org
Tue Jun 22 13:37:51 PDT 2021


Author: Aart Bik
Date: 2021-06-22T13:37:32-07:00
New Revision: 36b66ab9ed4f5eac721b3faea1f5b0bddd29c95b

URL: https://github.com/llvm/llvm-project/commit/36b66ab9ed4f5eac721b3faea1f5b0bddd29c95b
DIFF: https://github.com/llvm/llvm-project/commit/36b66ab9ed4f5eac721b3faea1f5b0bddd29c95b.diff

LOG: [mlir][sparse] add support for "simply dynamic" sparse tensor expressions

Slowly we are moving toward full support of sparse tensor *outputs*. First
step was support for all-dense annotated "sparse" tensors. This step adds
support for truly sparse tensors, but only for operations in which the values
of a tensor change, but not the nonzero structure (this was referred to as
"simply dynamic" in the [Bik96] thesis).

Some background text was posted on discourse:
https://llvm.discourse.group/t/sparse-tensors-in-mlir/3389/25

Reviewed By: gussmith23

Differential Revision: https://reviews.llvm.org/D104577

Added: 
    mlir/test/Dialect/SparseTensor/sparse_out.mlir

Modified: 
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/test/Dialect/SparseTensor/conversion.mlir
    mlir/test/Dialect/SparseTensor/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 32e433912b07c..7bde51d2dbab1 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -243,19 +243,10 @@ static LogicalResult verify(ToValuesOp op) {
   return success();
 }
 
-// TODO: generalize this beyond all-dense linearized "sparse" tensors
 static LogicalResult verify(ToTensorOp op) {
-  if (op.getNumOperands() != 1)
-    return op.emitError("expected single values array");
-  if (auto e = getSparseTensorEncoding(op.result().getType())) {
-    auto dlt = e.getDimLevelType();
-    for (unsigned i = 0, sz = dlt.size(); i < sz; i++) {
-      if (dlt[i] != SparseTensorEncodingAttr::DimLevelType::Dense)
-        return op.emitError("unexpected non-dense dimension");
-    }
-    return success();
-  }
-  return op.emitError("expected a sparse tensor as result");
+  if (!getSparseTensorEncoding(op.result().getType()))
+    return op.emitError("expected a sparse tensor as result");
+  return success();
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index e5fa176494ae0..6446bdb66d37e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -276,17 +276,27 @@ class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
   // Simply fold the operator into the pointer to the sparse storage scheme.
-  // TODO: generalize this beyond all-dense linearized "sparse" tensors
   matchAndRewrite(ToTensorOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (auto call = operands[0].getDefiningOp<CallOp>()) {
-      Value arg = call.getOperand(0);
-      if (arg.getType().isa<LLVM::LLVMPointerType>()) {
-        rewriter.replaceOp(op, arg);
-        return success();
+    // Check that all arguments of the tensor reconstruction operators are calls
+    // into the support library that query exactly the same opaque pointer.
+    Value ptr;
+    for (Value op : operands) {
+      if (auto call = op.getDefiningOp<CallOp>()) {
+        Value arg = call.getOperand(0);
+        if (!arg.getType().isa<LLVM::LLVMPointerType>())
+          return failure();
+        if (!ptr)
+          ptr = arg;
+        else if (arg != ptr)
+          return failure();
       }
     }
-    return failure();
+    // If a single opaque pointer is found, perform the folding.
+    if (!ptr)
+      return failure();
+    rewriter.replaceOp(op, ptr);
+    return success();
   }
 };
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index c3defdc4c5469..9c406a36f0728 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -367,7 +367,6 @@ static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) {
 /// Fills the per-dimension sparsity information for all tensors.
 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
   bool annotated = false;
-  OpOperand *lhs = op.getOutputOperand(0);
   for (OpOperand *t : op.getInputAndOutputOperands()) {
     auto map = op.getTiedIndexingMap(t);
     if (!map.isProjectedPermutation())
@@ -378,12 +377,7 @@ static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
     assert(map.getNumResults() == op.getRank(t));
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
       unsigned idx = map.getDimPosition(perm(enc, d));
-      Dim dim = toDim(enc, d);
       merger.setDim(t->getOperandNumber(), idx, toDim(enc, d));
-      // Accept only all-dense annotated "sparse" output.
-      // TODO: support truly sparse outputs too
-      if (t == lhs && dim != Dim::kDense)
-        return false;
     }
   }
   return annotated;
@@ -497,6 +491,55 @@ static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
   return None;
 }
 
+/// Returns true if given tensor co-iterates with conjunction only.
+/// For the output tensor, this defines a "simply dynamic" operation.
+/// For instance: A(I) = A(I) * B(I) * C(I)
+static unsigned isConjunction(Merger &merger, unsigned tensor, unsigned exp) {
+  switch (merger.exp(exp).kind) {
+  case Kind::kTensor:
+    return merger.exp(exp).e0 == tensor;
+  case Kind::kMulF:
+  case Kind::kMulI:
+    return isConjunction(merger, tensor, merger.exp(exp).e0) ||
+           isConjunction(merger, tensor, merger.exp(exp).e1);
+  default:
+    return false;
+  }
+}
+
+/// Returns true when the tensor expression is admissable for codegen.
+/// Since all sparse input tensors are admissable, we just need to check
+/// whether the output tensor in the tensor expression codegen is admissable.
+static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
+                                  unsigned exp) {
+  OpOperand *lhs = op.getOutputOperand(0);
+  unsigned tensor = lhs->getOperandNumber();
+  auto enc = getSparseTensorEncoding(lhs->get().getType());
+  // A non-annotated output tensor is assumed dense, and becomes a random
+  // access n-dim memref. Admissable since insertions cannot occur.
+  if (!enc)
+    return true;
+  // An all-dense annotated "sparse" output tensor becomes a linearized random
+  // access 1-dim memref. Also admissable since insertions cannot occur.
+  bool allDense = true;
+  unsigned numLoops = op.iterator_types().getValue().size();
+  for (unsigned i = 0; i < numLoops; i++)
+    if (merger.isDim(tensor, i, Dim::kSparse)) {
+      allDense = false;
+      break;
+    }
+  if (allDense)
+    return true;
+  // A tensor expression with a sparse output tensor that changes its values
+  // but not its nonzero structure, an operation called "simply dynamic" in
+  // [Bik96,Ch9], is also admissable without special codegen.
+  if (isConjunction(merger, tensor, exp))
+    return true;
+  // Reject for now since this requires changes to the nonzero structure.
+  // TODO: implement "workspaces" [Kjolstad2019]
+  return false;
+}
+
 /// Builds the iteration lattices in a bottom-up traversal given the remaining
 /// tensor (sub)expression and the next loop index in the iteration graph.
 static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
@@ -1391,15 +1434,34 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
 }
 
 /// Converts the result computed by the sparse kernel into the required form.
-static void genResult(CodeGen &codegen, PatternRewriter &rewriter,
-                      linalg::GenericOp op) {
-  RankedTensorType resType = op.getOutputTensorTypes()[0];
-  Value result = codegen.buffers.back();
-  if (getSparseTensorEncoding(resType))
-    result = rewriter.create<ToTensorOp>(op.getLoc(), resType, result);
-  else
-    result =
-        rewriter.create<memref::TensorLoadOp>(op.getLoc(), resType, result);
+static void genResult(Merger &merger, CodeGen &codegen,
+                      PatternRewriter &rewriter, linalg::GenericOp op) {
+  Location loc = op.getLoc();
+  OpOperand *lhs = op.getOutputOperand(0);
+  Type resType = lhs->get().getType();
+  unsigned tensor = lhs->getOperandNumber();
+  auto map = op.getTiedIndexingMap(lhs);
+  auto enc = getSparseTensorEncoding(resType);
+  Value result = codegen.buffers.back(); // value array
+  if (enc) {
+    // The sparse annotation unambiguously defines the arrays needed
+    // to "reconstruct" the sparse tensor from the storage scheme
+    // (even though lowering should never need this eventually).
+    SmallVector<Value, 4> args;
+    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
+      unsigned idx = map.getDimPosition(perm(enc, d));
+      if (merger.isDim(tensor, idx, Dim::kSparse)) {
+        args.push_back(codegen.pointers[tensor][idx]);
+        args.push_back(codegen.indices[tensor][idx]);
+      }
+    }
+    args.push_back(result);
+    result = rewriter.create<ToTensorOp>(loc, resType, args);
+  } else {
+    // To "reconstruct" a non-annotated tensor, simply load it
+    // from the bufferized value.
+    result = rewriter.create<memref::TensorLoadOp>(loc, resType, result);
+  }
   rewriter.replaceOp(op, result);
 }
 
@@ -1438,12 +1500,16 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     if (!exp.hasValue())
       return failure(); // build failure
 
+    // Reject an inadmissable tensor expression.
+    if (!isAdmissableTensorExp(merger, op, exp.getValue()))
+      return failure();
+
     // Recursively generates code.
     CodeGen codegen(options, numTensors, numLoops);
     if (!genBuffers(merger, codegen, rewriter, op))
       return failure(); // could not bufferize
     genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
-    genResult(codegen, rewriter, op);
+    genResult(merger, codegen, rewriter, op);
     return success();
   }
 

diff  --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 10a385f4d0729..ffc3b4d197669 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -190,11 +190,23 @@ func @sparse_valuesi8(%arg0: tensor<128xi8, #SparseVector>) -> memref<?xi8> {
   return %0 : memref<?xi8>
 }
 
-// CHECK-LABEL: func @sparse_reconstruct(
+// CHECK-LABEL: func @sparse_reconstruct_1(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>
 //       CHECK: return %[[A]] : !llvm.ptr<i8>
-func @sparse_reconstruct(%arg0: tensor<128xf32, #DenseVector> {linalg.inplaceable = true}) -> tensor<128xf32, #DenseVector> {
+func @sparse_reconstruct_1(%arg0: tensor<128xf32, #DenseVector> {linalg.inplaceable = true}) -> tensor<128xf32, #DenseVector> {
   %0 = sparse_tensor.values %arg0 : tensor<128xf32, #DenseVector> to memref<?xf32>
   %1 = sparse_tensor.tensor %0 : memref<?xf32> to tensor<128xf32, #DenseVector>
   return %1 : tensor<128xf32, #DenseVector>
 }
+
+// CHECK-LABEL: func @sparse_reconstruct_n(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>
+//       CHECK: return %[[A]] : !llvm.ptr<i8>
+func @sparse_reconstruct_n(%arg0: tensor<128xf32, #SparseVector> {linalg.inplaceable = true}) -> tensor<128xf32, #SparseVector> {
+  %c = constant 0 : index
+  %0 = sparse_tensor.pointers %arg0, %c : tensor<128xf32, #SparseVector> to memref<?xindex>
+  %1 = sparse_tensor.indices %arg0, %c : tensor<128xf32, #SparseVector> to memref<?xindex>
+  %2 = sparse_tensor.values %arg0 : tensor<128xf32, #SparseVector> to memref<?xf32>
+  %3 = sparse_tensor.tensor %0, %1, %2 : memref<?xindex>, memref<?xindex>, memref<?xf32> to tensor<128xf32, #SparseVector>
+  return %3 : tensor<128xf32, #SparseVector>
+}

diff  --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index febc2eb21f378..06a63cf37cf5b 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -93,26 +93,3 @@ func @sparse_to_unannotated_tensor(%arg0: memref<?xf64>) -> tensor<16x32xf64> {
   %0 = sparse_tensor.tensor %arg0 : memref<?xf64> to tensor<16x32xf64>
   return %0 : tensor<16x32xf64>
 }
-
-// -----
-
-#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["dense","compressed"]}>
-
-func @sparse_to_sparse_tensor(%arg0: memref<?xf64>) -> tensor<16x32xf64, #SparseMatrix> {
-  // expected-error@+1 {{unexpected non-dense dimension}}
-  %0 = sparse_tensor.tensor %arg0 : memref<?xf64> to tensor<16x32xf64, #SparseMatrix>
-  return %0 : tensor<16x32xf64, #SparseMatrix>
-}
-
-// -----
-
-#DenseMatrix = #sparse_tensor.encoding<{dimLevelType = ["dense","dense"]}>
-
-func @sparse_to_tensor(%arg0: memref<?xindex>,
-                       %arg1: memref<?xindex>,
-		       %arg2: memref<?xf64>) -> tensor<16x32xf64, #DenseMatrix> {
-  // expected-error@+1 {{expected single values array}}
-  %0 = sparse_tensor.tensor %arg0, %arg1, %arg2
-    : memref<?xindex>, memref<?xindex>, memref<?xf64> to tensor<16x32xf64, #DenseMatrix>
-  return %0 : tensor<16x32xf64, #DenseMatrix>
-}

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
new file mode 100644
index 0000000000000..aa37991a8b538
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
@@ -0,0 +1,133 @@
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  dimOrdering = affine_map<(i,j) -> (i,j)>
+}>
+
+#DCSR = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ],
+  dimOrdering = affine_map<(i,j) -> (i,j)>
+}>
+
+#trait_scale = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>   // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"],
+  doc = "X(i,j) = X(i,j) * 2"
+}
+
+// CHECK-LABEL:   func @sparse_simply_dynamic1(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {
+// CHECK:           %[[VAL_1:.*]] = constant 2.000000e+00 : f32
+// CHECK:           %[[VAL_2:.*]] = constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = constant 1 : index
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xf32>
+// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_3]] {
+// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK:             %[[VAL_13:.*]] = addi %[[VAL_11]], %[[VAL_3]] : index
+// CHECK:             %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xindex>
+// CHECK:             scf.for %[[VAL_15:.*]] = %[[VAL_12]] to %[[VAL_14]] step %[[VAL_3]] {
+// CHECK:               %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref<?xf32>
+// CHECK:               %[[VAL_17:.*]] = mulf %[[VAL_16]], %[[VAL_1]] : f32
+// CHECK:               memref.store %[[VAL_17]], %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref<?xf32>
+// CHECK:             }
+// CHECK:           }
+// CHECK:           %[[VAL_18:.*]] = sparse_tensor.tensor %[[VAL_4]], %[[VAL_5]], %[[VAL_6]], %[[VAL_7]], %[[VAL_8]] : memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32> to tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:           return %[[VAL_18]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:         }
+func @sparse_simply_dynamic1(%argx: tensor<32x16xf32, #DCSR> {linalg.inplaceable = true}) -> tensor<32x16xf32, #DCSR> {
+  %c = constant 2.0 : f32
+  %0 = linalg.generic #trait_scale
+    outs(%argx: tensor<32x16xf32, #DCSR>) {
+      ^bb(%x: f32):
+        %1 = mulf %x, %c : f32
+        linalg.yield %1 : f32
+  } -> tensor<32x16xf32, #DCSR>
+  return %0 : tensor<32x16xf32, #DCSR>
+}
+
+#trait_elt_wise_mult = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // A
+    affine_map<(i,j) -> (i,j)>   // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"],
+  doc = "X(i,j) = A(i,j) * X(i,j)"
+}
+
+// CHECK-LABEL:   func @sparse_simply_dynamic2(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-SAME:                                 %[[VAL_1:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {
+// CHECK:           %[[VAL_2:.*]] = constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = constant 1 : index
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xf32>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_2]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_2]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xf32>
+// CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_14:.*]] = %[[VAL_12]] to %[[VAL_13]] step %[[VAL_3]] {
+// CHECK:             %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_14]]] : memref<?xindex>
+// CHECK:             %[[VAL_16:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK:             %[[VAL_17:.*]] = addi %[[VAL_15]], %[[VAL_3]] : index
+// CHECK:             %[[VAL_18:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_17]]] : memref<?xindex>
+// CHECK:             %[[VAL_19:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_14]]] : memref<?xindex>
+// CHECK:             %[[VAL_20:.*]] = addi %[[VAL_14]], %[[VAL_3]] : index
+// CHECK:             %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK:             %[[VAL_22:.*]]:2 = scf.while (%[[VAL_23:.*]] = %[[VAL_16]], %[[VAL_24:.*]] = %[[VAL_19]]) : (index, index) -> (index, index) {
+// CHECK:               %[[VAL_25:.*]] = cmpi ult, %[[VAL_23]], %[[VAL_18]] : index
+// CHECK:               %[[VAL_26:.*]] = cmpi ult, %[[VAL_24]], %[[VAL_21]] : index
+// CHECK:               %[[VAL_27:.*]] = and %[[VAL_25]], %[[VAL_26]] : i1
+// CHECK:               scf.condition(%[[VAL_27]]) %[[VAL_23]], %[[VAL_24]] : index, index
+// CHECK:             } do {
+// CHECK:             ^bb0(%[[VAL_28:.*]]: index, %[[VAL_29:.*]]: index):
+// CHECK:               %[[VAL_30:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_28]]] : memref<?xindex>
+// CHECK:               %[[VAL_31:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_29]]] : memref<?xindex>
+// CHECK:               %[[VAL_32:.*]] = cmpi ult, %[[VAL_31]], %[[VAL_30]] : index
+// CHECK:               %[[VAL_33:.*]] = select %[[VAL_32]], %[[VAL_31]], %[[VAL_30]] : index
+// CHECK:               %[[VAL_34:.*]] = cmpi eq, %[[VAL_30]], %[[VAL_33]] : index
+// CHECK:               %[[VAL_35:.*]] = cmpi eq, %[[VAL_31]], %[[VAL_33]] : index
+// CHECK:               %[[VAL_36:.*]] = and %[[VAL_34]], %[[VAL_35]] : i1
+// CHECK:               scf.if %[[VAL_36]] {
+// CHECK:                 %[[VAL_37:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_29]]] : memref<?xf32>
+// CHECK:                 %[[VAL_38:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_28]]] : memref<?xf32>
+// CHECK:                 %[[VAL_39:.*]] = mulf %[[VAL_37]], %[[VAL_38]] : f32
+// CHECK:                 memref.store %[[VAL_39]], %[[VAL_11]]{{\[}}%[[VAL_29]]] : memref<?xf32>
+// CHECK:               } else {
+// CHECK:               }
+// CHECK:               %[[VAL_40:.*]] = cmpi eq, %[[VAL_30]], %[[VAL_33]] : index
+// CHECK:               %[[VAL_41:.*]] = addi %[[VAL_28]], %[[VAL_3]] : index
+// CHECK:               %[[VAL_42:.*]] = select %[[VAL_40]], %[[VAL_41]], %[[VAL_28]] : index
+// CHECK:               %[[VAL_43:.*]] = cmpi eq, %[[VAL_31]], %[[VAL_33]] : index
+// CHECK:               %[[VAL_44:.*]] = addi %[[VAL_29]], %[[VAL_3]] : index
+// CHECK:               %[[VAL_45:.*]] = select %[[VAL_43]], %[[VAL_44]], %[[VAL_29]] : index
+// CHECK:               scf.yield %[[VAL_42]], %[[VAL_45]] : index, index
+// CHECK:             }
+// CHECK:           }
+// CHECK:           %[[VAL_46:.*]] = sparse_tensor.tensor %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]], %[[VAL_11]] : memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32> to tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:           return %[[VAL_46]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:         }
+func @sparse_simply_dynamic2(%arga: tensor<32x16xf32, #CSR>,
+                             %argx: tensor<32x16xf32, #DCSR> {linalg.inplaceable = true}) -> tensor<32x16xf32, #DCSR> {
+  %0 = linalg.generic #trait_elt_wise_mult
+    ins(%arga: tensor<32x16xf32, #CSR>)
+    outs(%argx: tensor<32x16xf32, #DCSR>) {
+      ^bb(%a: f32, %x: f32):
+        %1 = mulf %x, %a : f32
+        linalg.yield %1 : f32
+  } -> tensor<32x16xf32, #DCSR>
+  return %0 : tensor<32x16xf32, #DCSR>
+}


        


More information about the Mlir-commits mailing list