[Mlir-commits] [mlir] [mlir][sparse] use sparse_tensor.insert to update values in all-dense… (PR #76544)

Peiming Liu llvmlistbot at llvm.org
Thu Dec 28 16:04:50 PST 2023


https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/76544

[mlir][sparse] Use sparse_tensor.insert to update values in an all-dense annotated output tensor.

From 2440d5067f435630c4d7e1c2bf70f599a8a0c22f Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 28 Dec 2023 23:35:05 +0000
Subject: [PATCH] [mlir][sparse] use sparse_tensor.insert to update values in
 an all-dense annotated output tensor.

---
 .../Transforms/Sparsification.cpp             | 15 +---
 .../Transforms/Utils/CodegenEnv.cpp           | 10 ++-
 .../Transforms/Utils/CodegenEnv.h             |  5 +-
 .../Transforms/Utils/LoopEmitter.cpp          |  8 +-
 mlir/test/Dialect/SparseTensor/dense.mlir     | 85 +++++++++----------
 mlir/test/Dialect/SparseTensor/one_trip.mlir  | 18 ++--
 .../Dialect/SparseTensor/sparse_affine.mlir   | 28 +++---
 .../Dialect/SparseTensor/sparse_index.mlir    | 42 +++++----
 8 files changed, 101 insertions(+), 110 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 934e1e559f44d6..6b8b0bed7521f5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -688,13 +688,13 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
         } else {
           env.startReduc(exp, genTensorLoad(env, builder, exp));
         }
-        if (env.hasSparseOutput())
+        if (env.hasTrulySparseOutput())
           env.startValidLexInsert(
               constantI1(builder, env.op().getLoc(), false));
       } else {
         if (!env.isCustomReduc() || env.isReduc())
           genTensorStore(env, builder, exp, env.endReduc());
-        if (env.hasSparseOutput())
+        if (env.hasTrulySparseOutput())
           env.endValidLexInsert();
       }
     } else {
@@ -769,7 +769,7 @@ static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr,
 /// converted to a parallel operation depends on the requested strategy.
 static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
   // Reject parallelization of sparse output.
-  if (env.hasSparseOutput())
+  if (env.hasTrulySparseOutput())
     return false;
   // Parallel loops on tensor expansion can cause data races.
   if (env.isExpand())
@@ -1038,8 +1038,6 @@ static bool translateBitsToTidLvlPairs(
     SmallVectorImpl<TensorLevel> &tidLvls,
     SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
   const BitVector &simple = env.lat(li).simple;
-  const TensorId outTid = env.merger().getOutTensorID();
-  const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
 
   unsigned numloopCond = 0;
   bool hasNonUnique = false;
@@ -1116,13 +1114,6 @@ static bool translateBitsToTidLvlPairs(
         }
       });
 
-  if (isDenseLT(env.lt(outTid, curr))) {
-    // Note that we generate dense indices of the output tensor
-    // unconditionally, since they may not appear in the lattice, but may be
-    // needed for linearized env.
-    tidLvls.push_back(env.makeTensorLevel(outTid, *outLvl));
-  }
-
   if (numloopCond == 0) {
     // Corner cases where the loop bound is defined by a *unused* operand, in
     // this case, we just generate a dense "fake" loop by iterating over the
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
index d3de55e4d59bd8..3beed08a738c46 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
@@ -147,10 +147,18 @@ bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
 
   OpOperand *lhs = linalgOp.getDpsInitOperand(0);
   const TensorId tensor = makeTensorId(lhs->getOperandNumber());
+  auto outStt = getSparseTensorType(lhs->get());
   // An non-annotated output tensor is assumed dense, and becomes a random
   // access n-dim memref. Admissible since insertions cannot occur.
-  if (getSparseTensorType(lhs->get()).isAllDense())
+  if (outStt.isAllDense()) {
+    // Treat an "all dense" annotated output tensor as a sparse output and
+    // handle it uniformly with "truly sparse" outputs, which avoids extra
+    // code for corner cases introduced by the use of "all dense" annotated
+    // tensors.
+    if (outStt.hasEncoding())
+      sparseOut = lhs;
     return true;
+  }
 
   // A tensor expression with a sparse output tensor that changes its values
   // but not its nonzero structure, an operation called "simply dynamic" in
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
index 728af841cc7b17..f1bd45c7a8f3de 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
@@ -125,7 +125,10 @@ class CodegenEnv {
   // Sparse tensor output and expansion methods.
   //
 
-  bool hasSparseOutput() const { return sparseOut != nullptr; }
+  bool hasTrulySparseOutput() const {
+    return sparseOut != nullptr &&
+           !getSparseTensorType(sparseOut->get()).isAllDense();
+  }
   bool isSparseOutput(OpOperand *o) const { return sparseOut == o; }
 
   Value getInsertionChain() const { return insChain; }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 80dad064676220..ccbbacc2ac379a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -1339,13 +1339,7 @@ void LoopEmitter::enterTensorsAtDenseLvls(
       assert(dependentLvlMap[tid][lvl].empty());
       auto enc = getSparseTensorEncoding(tensors[tid].getType());
       if (enc && !isSparseOutput(tid)) {
-        bool validPos = lvl == 0 || posits[tid][lvl - 1];
-        if (!validPos) {
-          // We might not find the pos for the sparse output tensor as it is
-          // unconditionally required by the sparsification.
-          assert(isOutputTensor(tid));
-          continue;
-        }
+        assert(lvl == 0 || posits[tid][lvl - 1]);
         posits[tid][lvl] = genAddress(builder, loc, tid, lvl, iv);
         // NOTE: we can also prepare for next lvl here in advance
       }
diff --git a/mlir/test/Dialect/SparseTensor/dense.mlir b/mlir/test/Dialect/SparseTensor/dense.mlir
index 2d8dcfea9adc19..a254161dd08984 100644
--- a/mlir/test/Dialect/SparseTensor/dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/dense.mlir
@@ -71,27 +71,27 @@ func.func @dense1(%arga: tensor<32x16xf32, #DenseMatrix>,
 // Test with a non-annotated dense matrix as input and
 // an all-dense annotated "sparse" matrix as output.
 //
-// CHECK-LABEL:   func @dense2(
-// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16xf32>,
-// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16xf32, #sparse{{[0-9]*}}>) -> tensor<32x16xf32, #sparse{{[0-9]*}}> {
-// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
-// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 16 : index
-// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
+
+// CHECK-LABEL:   func.func @dense2(
+// CHECK-SAME:                      %[[VAL_0:.*]]: tensor<32x16xf32>,
+// CHECK-SAME:                      %[[VAL_1:.*]]: tensor<32x16xf32, #sparse{{[0-9]*}}>) -> tensor<32x16xf32, #sparse{{[0-9]*}}> {
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 32 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 16 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK:           %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_0]] : memref<32x16xf32>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xf32>
-// CHECK:           scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
-// CHECK:             scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
-// CHECK:               %[[VAL_11:.*]] = arith.muli %[[VAL_9]], %[[VAL_4]] : index
-// CHECK:               %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_10]] : index
-// CHECK:               %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_9]], %[[VAL_10]]] : memref<32x16xf32>
-// CHECK:               %[[VAL_14:.*]] = arith.addf %[[VAL_13]], %[[VAL_2]] : f32
-// CHECK:               memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<?xf32>
-// CHECK:             }
-// CHECK:           }
-// CHECK:           %[[VAL_15:.*]] = sparse_tensor.load %[[VAL_1]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
-// CHECK:           return %[[VAL_15]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
+// CHECK:           %[[VAL_8:.*]] = scf.for %[[VAL_9:.*]] = %[[VAL_4]] to %[[VAL_2]] step %[[VAL_5]] iter_args(%[[VAL_10:.*]] = %[[VAL_1]]) -> (tensor<32x16xf32, #sparse{{[0-9]*}}>) {
+// CHECK:             %[[VAL_11:.*]] = scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] iter_args(%[[VAL_13:.*]] = %[[VAL_10]]) -> (tensor<32x16xf32, #sparse{{[0-9]*}}>) {
+// CHECK:               %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_9]], %[[VAL_12]]] : memref<32x16xf32>
+// CHECK:               %[[VAL_15:.*]] = arith.addf %[[VAL_14]], %[[VAL_6]] : f32
+// CHECK:               %[[VAL_16:.*]] = sparse_tensor.insert %[[VAL_15]] into %[[VAL_13]]{{\[}}%[[VAL_9]], %[[VAL_12]]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
+// CHECK:               scf.yield %[[VAL_16]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
+// CHECK:             } {"Emitted from" = "linalg.generic"}
+// CHECK:             scf.yield %[[VAL_11]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
+// CHECK:           } {"Emitted from" = "linalg.generic"}
+// CHECK:           %[[VAL_17:.*]] = sparse_tensor.load %[[VAL_8]] hasInserts : tensor<32x16xf32, #sparse{{[0-9]*}}>
+// CHECK:           return %[[VAL_17]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
 // CHECK:         }
 func.func @dense2(%arga: tensor<32x16xf32>,
                   %argx: tensor<32x16xf32, #DenseMatrix>)
@@ -114,31 +114,30 @@ func.func @dense2(%arga: tensor<32x16xf32>,
 // The missing innermost "k" index (due to a reduction) is accounted
 // for by scalarizing the reduction operation for the output tensor.
 //
-// CHECK-LABEL:   func @dense3(
-// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32>,
-// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16xf32, #sparse{{[0-9]*}}>) -> tensor<32x16xf32, #sparse{{[0-9]*}}> {
-// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 8 : index
-// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
-// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 16 : index
+// CHECK-LABEL:   func.func @dense3(
+// CHECK-SAME:                      %[[VAL_0:.*]]: tensor<32x16x8xf32>,
+// CHECK-SAME:                      %[[VAL_1:.*]]: tensor<32x16xf32, #sparse{{[0-9]*}}>) -> tensor<32x16xf32, #sparse{{[0-9]*}}> {
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 32 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 16 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 8 : index
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_0]] : memref<32x16x8xf32>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xf32>
-// CHECK:           scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
-// CHECK:             scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
-// CHECK:               %[[VAL_11:.*]] = arith.muli %[[VAL_9]], %[[VAL_4]] : index
-// CHECK:               %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_10]] : index
-// CHECK:               %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<?xf32>
-// CHECK:               %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_6]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (f32) {
-// CHECK:                 %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_9]], %[[VAL_10]], %[[VAL_15]]] : memref<32x16x8xf32>
-// CHECK:                 %[[VAL_18:.*]] = arith.addf %[[VAL_16]], %[[VAL_17]] : f32
-// CHECK:                 scf.yield %[[VAL_18]] : f32
-// CHECK:               }
-// CHECK:               memref.store %[[VAL_19:.*]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<?xf32>
-// CHECK:             }
-// CHECK:           }
-// CHECK:           %[[VAL_20:.*]] = sparse_tensor.load %[[VAL_1]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
-// CHECK:           return %[[VAL_20]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<32x16x8xf32>
+// CHECK:           %[[VAL_9:.*]] = scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_6]] iter_args(%[[VAL_11:.*]] = %[[VAL_1]]) -> (tensor<32x16xf32, #sparse{{[0-9]*}}>) {
+// CHECK:             %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (tensor<32x16xf32, #sparse{{[0-9]*}}>) {
+// CHECK:               %[[VAL_15:.*]] = scf.for %[[VAL_16:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] iter_args(%[[VAL_17:.*]] = %[[VAL_7]]) -> (f32) {
+// CHECK:                 %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_10]], %[[VAL_13]], %[[VAL_16]]] : memref<32x16x8xf32>
+// CHECK:                 %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : f32
+// CHECK:                 scf.yield %[[VAL_19]] : f32
+// CHECK:               } {"Emitted from" = "linalg.generic"}
+// CHECK:               %[[VAL_20:.*]] = sparse_tensor.insert %[[VAL_15]] into %[[VAL_14]]{{\[}}%[[VAL_10]], %[[VAL_13]]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
+// CHECK:               scf.yield %[[VAL_20]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
+// CHECK:             } {"Emitted from" = "linalg.generic"}
+// CHECK:             scf.yield %[[VAL_12]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
+// CHECK:           } {"Emitted from" = "linalg.generic"}
+// CHECK:           %[[VAL_21:.*]] = sparse_tensor.load %[[VAL_9]] hasInserts : tensor<32x16xf32, #sparse{{[0-9]*}}>
+// CHECK:           return %[[VAL_21]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
 // CHECK:         }
 func.func @dense3(%arga: tensor<32x16x8xf32>,
                   %argx: tensor<32x16xf32, #DenseMatrix>)
diff --git a/mlir/test/Dialect/SparseTensor/one_trip.mlir b/mlir/test/Dialect/SparseTensor/one_trip.mlir
index 9d2a125dbef171..ab10222d733726 100644
--- a/mlir/test/Dialect/SparseTensor/one_trip.mlir
+++ b/mlir/test/Dialect/SparseTensor/one_trip.mlir
@@ -12,16 +12,14 @@
   doc = "X(i,j) = X(i,j) * 2.0"
 }
 
-// CHECK-LABEL: func.func @sparse_scale(
-// CHECK-SAME:    %[[VAL_0:.*]]: tensor<1x1xf32, #sparse{{[0-9]*}}>)
-// CHECK-DAG:     %[[VAL_1:.*]] = arith.constant 0 : index
-// CHECK-DAG:     %[[VAL_2:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK:         %[[VAL_3:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<1x1xf32, #sparse{{[0-9]*}}> to memref<?xf32>
-// CHECK:         %[[VAL_4:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_1]]] : memref<?xf32>
-// CHECK:         %[[VAL_5:.*]] = arith.mulf %[[VAL_4]], %[[VAL_2]] : f32
-// CHECK:         memref.store %[[VAL_5]], %[[VAL_3]]{{\[}}%[[VAL_1]]] : memref<?xf32>
-// CHECK:         %[[VAL_6:.*]] = sparse_tensor.load %[[VAL_0]] : tensor<1x1xf32, #sparse{{[0-9]*}}>
-// CHECK:         return %[[VAL_6]] : tensor<1x1xf32, #sparse{{[0-9]*}}>
+// CHECK-LABEL:   func.func @sparse_scale(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<1x1xf32, #sparse{{[0-9]*}}>) -> tensor<1x1xf32, #sparse{{[0-9]*}}> {
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_3:.*]] = sparse_tensor.insert %[[VAL_2]] into %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : tensor<1x1xf32, #sparse{{[0-9]*}}>
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.load %[[VAL_3]] hasInserts : tensor<1x1xf32, #sparse{{[0-9]*}}>
+// CHECK:           return %[[VAL_4]] : tensor<1x1xf32, #sparse{{[0-9]*}}>
+// CHECK:         }
 func.func @sparse_scale(%argx: tensor<1x1xf32, #Dense>) -> tensor<1x1xf32, #Dense> {
   %c = arith.constant 2.0 : f32
   %0 = linalg.generic #trait_scale
diff --git a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
index aa75e20460f51f..05fa4ca22c1b83 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
@@ -62,20 +62,20 @@ func.func @mul_inv_dense1d(%arga: tensor<32xf32, #SpVec>,
 // CHECK:           %[[VAL_3:.*]] = arith.constant 3 : index
 // CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_6:.*]] = tensor.empty() : tensor<32xf32, #sparse{{[0-9]*}}>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse{{[0-9]*}}> to memref<?xf32>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<4xf32, #sparse{{[0-9]*}}> to memref<?xf32>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_6]] : tensor<32xf32, #sparse{{[0-9]*}}> to memref<?xf32>
-// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_3]]] : memref<?xf32>
-// CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_2]] step %[[VAL_5]] {
-// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_11]]] : memref<?xf32>
-// CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xf32>
-// CHECK:             %[[VAL_14:.*]] = arith.mulf %[[VAL_13]], %[[VAL_10]] : f32
-// CHECK:             %[[VAL_15:.*]] = arith.addf %[[VAL_12]], %[[VAL_14]] : f32
-// CHECK:             memref.store %[[VAL_15]], %[[VAL_9]]{{\[}}%[[VAL_11]]] : memref<?xf32>
-// CHECK:           }
-// CHECK:           %[[VAL_16:.*]] = sparse_tensor.load %[[VAL_6]] : tensor<32xf32, #sparse{{[0-9]*}}>
-// CHECK:           return %[[VAL_16]] : tensor<32xf32, #sparse{{[0-9]*}}>
+// CHECK:           %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_7:.*]] = tensor.empty() : tensor<32xf32, #sparse{{[0-9]*}}>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse{{[0-9]*}}> to memref<?xf32>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<4xf32, #sparse{{[0-9]*}}> to memref<?xf32>
+// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref<?xf32>
+// CHECK:           %[[VAL_11:.*]] = scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_2]] step %[[VAL_5]] iter_args(%[[VAL_13:.*]] = %[[VAL_7]]) -> (tensor<32xf32, #sparse{{[0-9]*}}>) {
+// CHECK:             %[[VAL_14:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<?xf32>
+// CHECK:             %[[VAL_15:.*]] = arith.mulf %[[VAL_14]], %[[VAL_10]] : f32
+// CHECK:             %[[VAL_16:.*]] = arith.addf %[[VAL_15]], %[[VAL_6]] : f32
+// CHECK:             %[[VAL_17:.*]] = sparse_tensor.insert %[[VAL_16]] into %[[VAL_13]]{{\[}}%[[VAL_12]]] : tensor<32xf32, #sparse{{[0-9]*}}>
+// CHECK:             scf.yield %[[VAL_17]] : tensor<32xf32, #sparse{{[0-9]*}}>
+// CHECK:           } {"Emitted from" = "linalg.generic"}
+// CHECK:           %[[VAL_18:.*]] = sparse_tensor.load %[[VAL_11]] hasInserts : tensor<32xf32, #sparse{{[0-9]*}}>
+// CHECK:           return %[[VAL_18]] : tensor<32xf32, #sparse{{[0-9]*}}>
 // CHECK:         }
 func.func @mul_inv_enc_dense1d(%arga: tensor<32xf32, #EncDenseVec>,
                             %argb: tensor<4xf32, #EncDenseVec>) -> tensor<32xf32, #EncDenseVec> {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
index b09bd0a7400941..0d8e10a8541e20 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_index.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
@@ -18,33 +18,31 @@
 }
 
 // CHECK-LABEL:   func.func @dense_index(
-// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?xi64, #sparse{{[0-9]*}}>
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?xi64, #sparse{{[0-9]*}}>) -> tensor<?x?xi64, #sparse{{[0-9]*}}> {
 // CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_3:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
 // CHECK-DAG:       %[[VAL_4:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
 // CHECK-DAG:       %[[VAL_5:.*]] = tensor.empty(%[[VAL_3]], %[[VAL_4]]) : tensor<?x?xi64, #sparse{{[0-9]*}}>
-// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
-// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
-// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_2]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
-// CHECK-DAG:       %[[VAL_24:.*]] = sparse_tensor.lvl %[[VAL_5]], %[[VAL_2]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
-// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_5]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
-// CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_1]] to %[[VAL_7]] step %[[VAL_2]] {
-// CHECK:             scf.for %[[VAL_11:.*]] = %[[VAL_1]] to %[[VAL_8]] step %[[VAL_2]] {
-// CHECK:               %[[VAL_12:.*]] = arith.muli %[[VAL_8]], %[[VAL_10]] : index
-// CHECK:               %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_11]] : index
-// CHECK:               %[[VAL_14:.*]] = arith.muli %[[VAL_24]], %[[VAL_10]] : index
-// CHECK:               %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_11]] : index
-// CHECK:               %[[VAL_16:.*]] = arith.index_cast %[[VAL_11]] : index to i64
-// CHECK:               %[[VAL_17:.*]] = arith.index_cast %[[VAL_10]] : index to i64
-// CHECK:               %[[VAL_18:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xi64>
-// CHECK:               %[[VAL_19:.*]] = arith.muli %[[VAL_17]], %[[VAL_18]] : i64
-// CHECK:               %[[VAL_20:.*]] = arith.muli %[[VAL_16]], %[[VAL_19]] : i64
-// CHECK:               memref.store %[[VAL_20]], %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref<?xi64>
-// CHECK:             }
-// CHECK:           }
-// CHECK:           %[[VAL_21:.*]] = sparse_tensor.load %[[VAL_5]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
-// CHECK:           return %[[VAL_21]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
+// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
+// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_2]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
+// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xi64, #sparse{{[0-9]*}}> to memref<?xi64>
+// CHECK:           %[[VAL_9:.*]] = scf.for %[[VAL_10:.*]] = %[[VAL_1]] to %[[VAL_6]] step %[[VAL_2]] iter_args(%[[VAL_11:.*]] = %[[VAL_5]]) -> (tensor<?x?xi64, #sparse{{[0-9]*}}>) {
+// CHECK:             %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_1]] to %[[VAL_7]] step %[[VAL_2]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (tensor<?x?xi64, #sparse{{[0-9]*}}>) {
+// CHECK:               %[[VAL_15:.*]] = arith.muli %[[VAL_7]], %[[VAL_10]] : index
+// CHECK:               %[[VAL_16:.*]] = arith.addi %[[VAL_15]], %[[VAL_13]] : index
+// CHECK:               %[[VAL_17:.*]] = arith.index_cast %[[VAL_13]] : index to i64
+// CHECK:               %[[VAL_18:.*]] = arith.index_cast %[[VAL_10]] : index to i64
+// CHECK:               %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xi64>
+// CHECK:               %[[VAL_20:.*]] = arith.muli %[[VAL_18]], %[[VAL_19]] : i64
+// CHECK:               %[[VAL_21:.*]] = arith.muli %[[VAL_17]], %[[VAL_20]] : i64
+// CHECK:               %[[VAL_22:.*]] = sparse_tensor.insert %[[VAL_21]] into %[[VAL_14]]{{\[}}%[[VAL_10]], %[[VAL_13]]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
+// CHECK:               scf.yield %[[VAL_22]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
+// CHECK:             } {"Emitted from" = "linalg.generic"}
+// CHECK:             scf.yield %[[VAL_12]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
+// CHECK:           } {"Emitted from" = "linalg.generic"}
+// CHECK:           %[[VAL_23:.*]] = sparse_tensor.load %[[VAL_9]] hasInserts : tensor<?x?xi64, #sparse{{[0-9]*}}>
+// CHECK:           return %[[VAL_23]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
 // CHECK:         }
 func.func @dense_index(%arga: tensor<?x?xi64, #DenseMatrix>)
                       -> tensor<?x?xi64, #DenseMatrix> {



More information about the Mlir-commits mailing list