[Mlir-commits] [mlir] 81d0d2b - [mlir][sparse] Sparse reduction in lex order no longer produces dense output

Jim Kitchen llvmlistbot at llvm.org
Fri Feb 10 11:09:40 PST 2023


Author: Jim Kitchen
Date: 2023-02-10T13:09:28-06:00
New Revision: 81d0d2b2a068eae9692b9317bceaaea252c1bbf8

URL: https://github.com/llvm/llvm-project/commit/81d0d2b2a068eae9692b9317bceaaea252c1bbf8
DIFF: https://github.com/llvm/llvm-project/commit/81d0d2b2a068eae9692b9317bceaaea252c1bbf8.diff

LOG: [mlir][sparse] Sparse reduction in lex order no longer produces dense output

Previously, when performing a reduction on a sparse tensor, the result
could differ depending on the iteration order. With the expanded access
pattern, an empty row contributed no entry to the output, whereas with lex
ordering the reduction identity ended up in the output.

This change tracks whether any entries were actually reduced when iterating
in lex order and only performs the insertion when they were, making the
output consistent between the two iteration styles.
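
As a concrete illustration, consider a row-sum of a CSR matrix into a sparse
vector (a hypothetical kernel sketched here for exposition; it is not one of
the tests touched by this patch). With lex insertion, a row of %a with no
stored entries used to receive the additive identity (0.0) in the output;
with the new valid-lex-insert flag, such rows produce no stored entry, which
matches the expanded-access code path:

  #CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>
  #SV  = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

  // Row-sum reduction: out(i) = sum over j of a(i, j).
  %out = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%a : tensor<8x8xf64, #CSR>)
      outs(%init : tensor<8xf64, #SV>) {
    ^bb0(%x: f64, %acc: f64):
      %sum = arith.addf %acc, %x : f64
      linalg.yield %sum : f64
  } -> tensor<8xf64, #SV>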

Differential Revision: https://reviews.llvm.org/D142050

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/test/Dialect/SparseTensor/sparse_affine.mlir
    mlir/test/Dialect/SparseTensor/sparse_out.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 4a8fbf76bd567..a8ebad3379af3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -37,7 +37,7 @@ CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
       latticeMerger(numTensors, numLoops, numFilterLoops), loopEmitter(),
       topSort(), sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(),
       expFilled(), expAdded(), expCount(), redVal(), redExp(-1u),
-      redCustom(-1u) {}
+      redCustom(-1u), redValidLexInsert() {}
 
 LogicalResult CodegenEnv::initTensorExp() {
   // Builds the tensor expression for the Linalg operation in SSA form.
@@ -70,16 +70,24 @@ std::optional<Operation *> CodegenEnv::genLoopBoundary(
     function_ref<std::optional<Operation *>(MutableArrayRef<Value> parameters)>
         callback) {
   SmallVector<Value> params;
-  if (isReduc())
+  if (isReduc()) {
     params.push_back(redVal);
+    if (redValidLexInsert)
+      params.push_back(redValidLexInsert);
+  } else {
+    assert(!redValidLexInsert);
+  }
   if (isExpand())
     params.push_back(expCount);
   if (insChain != nullptr)
     params.push_back(insChain);
   auto r = callback(params); // may update parameters
   unsigned i = 0;
-  if (isReduc())
+  if (isReduc()) {
     updateReduc(params[i++]);
+    if (redValidLexInsert)
+      setValidLexInsert(params[i++]);
+  }
   if (isExpand())
     updateExpandCount(params[i++]);
   if (insChain != nullptr)
@@ -225,6 +233,16 @@ Value CodegenEnv::endReduc() {
   return val;
 }
 
+void CodegenEnv::setValidLexInsert(Value val) {
+  assert(isReduc() && val);
+  redValidLexInsert = val;
+}
+
+void CodegenEnv::clearValidLexInsert() {
+  assert(!isReduc());
+  redValidLexInsert = Value();
+}
+
 void CodegenEnv::startCustomReduc(unsigned exp) {
   assert(redCustom == -1u && exp != -1u);
   redCustom = exp;

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index 6f16d8f8daef9..b210ca873520f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -132,6 +132,9 @@ class CodegenEnv {
   void updateReduc(Value val);
   Value getReduc() const { return redVal; }
   Value endReduc();
+  void setValidLexInsert(Value val);
+  void clearValidLexInsert();
+  Value getValidLexInsert() const { return redValidLexInsert; }
 
   void startCustomReduc(unsigned exp);
   bool isCustomReduc() const { return redCustom != -1u; }
@@ -172,6 +175,11 @@ class CodegenEnv {
   unsigned redExp;
   unsigned redCustom;
 
+  // Bookkeeping for lex insertion during reductions. Holds the runtime boolean
+  // value of whether any reduction occurred. This is only set during a
+  // reduction and cleared once the reduction is finished.
+  Value redValidLexInsert;
+
   // The root tensor expression of the kernel.
   unsigned tensorExp;
 };
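
The intended driving sequence for these new hooks, paraphrased from the
Sparsification.cpp changes below (env, builder, loc, exp, and load are
assumed to be in scope as they are in the sparsifier):

  // At the start of a scalarized reduction with a sparse output, the flag
  // starts out false.
  env.startReduc(exp, load);
  if (env.hasSparseOutput())
    env.setValidLexInsert(constantI1(builder, loc, false));

  // While emitting the lex insertion, only insert when something was reduced;
  // otherwise the insertion chain is passed through unchanged.
  if (env.getValidLexInsert()) {
    // wrap the sparse_tensor.insert in an scf.if on the flag ...
  }

  // At the end of the reduction, store the result and clear the flag.
  genTensorStore(env, builder, exp, env.endReduc());
  env.clearValidLexInsert();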

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 72063a16c7f05..f3cf2ceb60726 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -651,8 +651,31 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
       indices.push_back(env.emitter().getLoopIV(i));
     }
     Value chain = env.getInsertionChain();
-    env.updateInsertionChain(
-        builder.create<InsertOp>(loc, rhs, chain, indices));
+    if (!env.getValidLexInsert()) {
+      env.updateInsertionChain(
+          builder.create<InsertOp>(loc, rhs, chain, indices));
+    } else {
+      // Generates runtime check for a valid lex during reduction,
+      // to avoid inserting the identity value for empty reductions.
+      //   if (validLexInsert) then
+      //     insert(rhs) into chain
+      //     return updated chain
+      //   else
+      //     return unmodified chain
+      scf::IfOp ifValidLexInsert = builder.create<scf::IfOp>(
+          loc, chain.getType(), env.getValidLexInsert(),
+          /*else=*/true);
+      // True branch.
+      builder.setInsertionPointToStart(ifValidLexInsert.thenBlock());
+      Value res = builder.create<InsertOp>(loc, rhs, chain, indices);
+      builder.create<scf::YieldOp>(loc, res);
+      // False branch.
+      builder.setInsertionPointToStart(ifValidLexInsert.elseBlock());
+      builder.create<scf::YieldOp>(loc, chain);
+      // Value assignment.
+      builder.setInsertionPointAfter(ifValidLexInsert);
+      env.updateInsertionChain(ifValidLexInsert.getResult(0));
+    }
     return;
   }
   // Generates insertion code along expanded access pattern.
@@ -857,13 +880,16 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, unsigned exp,
       return;
     OpOperand *lhs = op.getDpsInitOperand(0);
     if (lhs == &t) {
-      // Start or end a scalarized reduction
+      // Start or end a scalarized reduction.
       if (atStart) {
         Value load = env.isCustomReduc() ? env.getCustomRedId()
                                          : genTensorLoad(env, builder, exp);
         env.startReduc(exp, load);
+        if (env.hasSparseOutput())
+          env.setValidLexInsert(constantI1(builder, env.op().getLoc(), false));
       } else {
         genTensorStore(env, builder, exp, env.endReduc());
+        env.clearValidLexInsert();
       }
     } else {
       // Start or end loop invariant hoisting of a tensor load.
@@ -1031,6 +1057,10 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, unsigned idx,
       if (env.isReduc()) {
         yields.push_back(env.getReduc());
         env.updateReduc(ifOp.getResult(y++));
+        if (env.getValidLexInsert()) {
+          yields.push_back(env.getValidLexInsert());
+          env.setValidLexInsert(ifOp.getResult(y++));
+        }
       }
       if (env.isExpand()) {
         yields.push_back(env.getExpandCount());
@@ -1073,8 +1103,11 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, unsigned idx,
     }
     cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
   }
-  if (env.isReduc())
+  if (env.isReduc()) {
     types.push_back(env.getReduc().getType());
+    if (env.getValidLexInsert())
+      types.push_back(env.getValidLexInsert().getType());
+  }
   if (env.isExpand())
     types.push_back(builder.getIndexType());
   if (env.getInsertionChain())
@@ -1092,6 +1125,9 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
   if (env.isReduc()) {
     operands.push_back(env.getReduc());
     env.updateReduc(redInput);
+    if (env.getValidLexInsert())
+      // Any overlapping indices during a reduction creates a valid lex insert.
+      operands.push_back(constantI1(builder, env.op().getLoc(), true));
   }
   if (env.isExpand()) {
     operands.push_back(env.getExpandCount());
@@ -1318,6 +1354,10 @@ static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
   // End a while-loop.
   if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
     finalizeWhileOp(env, rewriter, idx, needsUniv, env.lat(li).bits, whileOp);
+  } else if (auto forOp = dyn_cast<scf::ForOp>(loop)) {
+    // Any iteration of a reduction for-loop creates a valid lex insert.
+    if (env.isReduc() && env.getValidLexInsert())
+      env.setValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
   } else {
     needsUniv = false;
   }
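
In IR terms, the loops emitted for a reduction into a sparse output now look
roughly as follows (a simplified, hypothetical sketch: SSA names, the reduced
expression, and the #SV sparse-vector type are illustrative, and the real
output also threads the output tensor through the loop's iter_args, as the
updated CHECK lines in the tests below show):

  %red:2 = scf.for %jj = %lo to %hi step %c1
      iter_args(%acc = %identity, %valid = %false) -> (f64, i1) {
    %v = memref.load %values[%jj] : memref<?xf64>
    %sum = arith.addf %acc, %v : f64
    scf.yield %sum, %true : f64, i1
  }
  %chain2 = scf.if %red#1 -> (tensor<8xf64, #SV>) {
    %ins = sparse_tensor.insert %red#0 into %chain[%i] : tensor<8xf64, #SV>
    scf.yield %ins : tensor<8xf64, #SV>
  } else {
    scf.yield %chain : tensor<8xf64, #SV>
  }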

diff --git a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
index 8af3144055c2a..3d2f70611eb6e 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
@@ -306,12 +306,14 @@ func.func @mul_affine_dense2d(%arga: tensor<32x16xf64, #CSR>,
 // CHECK-LABEL:   func.func @mul_affine_sparse2d(
 // CHECK-SAME:                                   %[[VAL_0:.*]]: tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>,
 // CHECK-SAME:                                   %[[VAL_1:.*]]: tensor<34x19xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> {
-// CHECK:           %[[VAL_2:.*]] = arith.constant 32 : index
-// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = arith.constant 2 : index
-// CHECK:           %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK:           %[[VAL_7:.*]] = arith.constant 3 : index
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 32 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 2 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 3 : index
+// CHECK-DAG:       %[[VAL_TRUE:.*]] = arith.constant true
+// CHECK-DAG:       %[[VAL_FALSE:.*]] = arith.constant false
 // CHECK:           %[[VAL_8:.*]] = bufferization.alloc_tensor() : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
@@ -330,22 +332,27 @@ func.func @mul_affine_dense2d(%arga: tensor<32x16xf64, #CSR>,
 // CHECK:               %[[VAL_27:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_18]]] : memref<?xindex>
 // CHECK:               %[[VAL_28:.*]] = arith.addi %[[VAL_18]], %[[VAL_4]] : index
 // CHECK:               %[[VAL_29:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_28]]] : memref<?xindex>
-// CHECK:               %[[VAL_30:.*]]:2 = scf.for %[[VAL_31:.*]] = %[[VAL_27]] to %[[VAL_29]] step %[[VAL_4]] iter_args(%[[VAL_32:.*]] = %[[VAL_6]], %[[VAL_33:.*]] = %[[VAL_24]]) -> (f64, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>) {
+// CHECK:               %[[VAL_30:.*]]:3 = scf.for %[[VAL_31:.*]] = %[[VAL_27]] to %[[VAL_29]] step %[[VAL_4]] iter_args(%[[VAL_32:.*]] = %[[VAL_6]], %[[VAL_200:.*]] = %[[VAL_FALSE]], %[[VAL_33:.*]] = %[[VAL_24]]) -> (f64, i1, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>) {
 // CHECK:                 %[[VAL_34:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_31]]] : memref<?xindex>
 // CHECK:                 %[[VAL_35:.*]] = arith.addi %[[VAL_25]], %[[VAL_7]] : index
 // CHECK:                 %[[VAL_36:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_35]] : index
-// CHECK:                 %[[VAL_37:.*]]:2 = scf.if %[[VAL_36]] -> (f64, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>) {
+// CHECK:                 %[[VAL_37:.*]]:3 = scf.if %[[VAL_36]] -> (f64, i1, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>) {
 // CHECK:                   %[[VAL_38:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_31]]] : memref<?xf64>
 // CHECK:                   %[[VAL_39:.*]] = arith.mulf %[[VAL_26]], %[[VAL_38]] : f64
 // CHECK:                   %[[VAL_40:.*]] = arith.addf %[[VAL_32]], %[[VAL_39]] : f64
-// CHECK:                   scf.yield %[[VAL_40]], %[[VAL_33]] : f64, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:                   scf.yield %[[VAL_40]], %[[VAL_TRUE]], %[[VAL_33]] : f64, i1, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:                 } else {
-// CHECK:                   scf.yield %[[VAL_32]], %[[VAL_33]] : f64, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:                   scf.yield %[[VAL_32]], %[[VAL_200]], %[[VAL_33]] : f64, i1, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:                 }
-// CHECK:                 scf.yield %[[VAL_41:.*]]#0, %[[VAL_41]]#1 : f64, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:                 scf.yield %[[VAL_41:.*]]#0, %[[VAL_41]]#1, %[[VAL_41]]#2 : f64, i1, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:               }
+// CHECK:               %[[VAL_201:.*]] = scf.if %[[VAL_30]]#1 -> (tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>) {
+// CHECK:                 %[[VAL_42:.*]] = sparse_tensor.insert %[[VAL_30]]#0 into %[[VAL_30]]#2{{\[}}%[[VAL_16]], %[[VAL_25]]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:                 scf.yield %[[VAL_42]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:               } else {
+// CHECK:                 scf.yield %[[VAL_30]]#2 : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:               }
-// CHECK:               %[[VAL_42:.*]] = sparse_tensor.insert %[[VAL_43:.*]]#0 into %[[VAL_43]]#1{{\[}}%[[VAL_16]], %[[VAL_25]]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK:               scf.yield %[[VAL_42]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:               scf.yield %[[VAL_201]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:             }
 // CHECK:             scf.yield %[[VAL_44:.*]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:           }

diff --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
index 99b9dffeb990e..e6d6c2364fc16 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_out.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
@@ -153,6 +153,8 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : i32
+// CHECK-DAG:       %[[VAL_FALSE:.*]] = arith.constant false
+// CHECK-DAG:       %[[VAL_TRUE:.*]] = arith.constant true
 // CHECK:           %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>>
 // CHECK:           %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>>
 // CHECK:           %[[VAL_7:.*]] = bufferization.alloc_tensor(%[[VAL_5]], %[[VAL_6]]) : tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
@@ -216,13 +218,13 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x
 // CHECK:                   %[[VAL_71:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_58]]] : memref<?xindex>
 // CHECK:                   %[[VAL_72:.*]] = arith.addi %[[VAL_58]], %[[VAL_3]] : index
 // CHECK:                   %[[VAL_73:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_72]]] : memref<?xindex>
-// CHECK:                   %[[VAL_74:.*]]:4 = scf.while (%[[VAL_75:.*]] = %[[VAL_68]], %[[VAL_76:.*]] = %[[VAL_71]], %[[VAL_77:.*]] = %[[VAL_4]], %[[VAL_78:.*]] = %[[VAL_59]]) : (index, index, i32, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) -> (index, index, i32, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) {
+// CHECK:                   %[[VAL_74:.*]]:5 = scf.while (%[[VAL_75:.*]] = %[[VAL_68]], %[[VAL_76:.*]] = %[[VAL_71]], %[[VAL_77:.*]] = %[[VAL_4]], %[[VAL_200:.*]] = %[[VAL_FALSE]], %[[VAL_78:.*]] = %[[VAL_59]]) : (index, index, i32, i1, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) -> (index, index, i32, i1, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) {
 // CHECK:                     %[[VAL_79:.*]] = arith.cmpi ult, %[[VAL_75]], %[[VAL_70]] : index
 // CHECK:                     %[[VAL_80:.*]] = arith.cmpi ult, %[[VAL_76]], %[[VAL_73]] : index
 // CHECK:                     %[[VAL_81:.*]] = arith.andi %[[VAL_79]], %[[VAL_80]] : i1
-// CHECK:                     scf.condition(%[[VAL_81]]) %[[VAL_75]], %[[VAL_76]], %[[VAL_77]], %[[VAL_78]] : index, index, i32, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK:                     scf.condition(%[[VAL_81]]) %[[VAL_75]], %[[VAL_76]], %[[VAL_77]], %[[VAL_200]], %[[VAL_78]] : index, index, i32, i1, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
 // CHECK:                   } do {
-// CHECK:                   ^bb0(%[[VAL_82:.*]]: index, %[[VAL_83:.*]]: index, %[[VAL_84:.*]]: i32, %[[VAL_85:.*]]: tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>):
+// CHECK:                   ^bb0(%[[VAL_82:.*]]: index, %[[VAL_83:.*]]: index, %[[VAL_84:.*]]: i32, %[[VAL_201:.*]]: i1, %[[VAL_85:.*]]: tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>):
 // CHECK:                     %[[VAL_86:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_82]]] : memref<?xindex>
 // CHECK:                     %[[VAL_87:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_83]]] : memref<?xindex>
 // CHECK:                     %[[VAL_88:.*]] = arith.cmpi ult, %[[VAL_87]], %[[VAL_86]] : index
@@ -230,14 +232,14 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x
 // CHECK:                     %[[VAL_90:.*]] = arith.cmpi eq, %[[VAL_86]], %[[VAL_89]] : index
 // CHECK:                     %[[VAL_91:.*]] = arith.cmpi eq, %[[VAL_87]], %[[VAL_89]] : index
 // CHECK:                     %[[VAL_92:.*]] = arith.andi %[[VAL_90]], %[[VAL_91]] : i1
-// CHECK:                     %[[VAL_93:.*]]:2 = scf.if %[[VAL_92]] -> (i32, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) {
+// CHECK:                     %[[VAL_93:.*]]:3 = scf.if %[[VAL_92]] -> (i32, i1, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) {
 // CHECK:                       %[[VAL_94:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_82]]] : memref<?xi32>
 // CHECK:                       %[[VAL_95:.*]] = memref.load %[[VAL_21]]{{\[}}%[[VAL_83]]] : memref<?xi32>
 // CHECK:                       %[[VAL_96:.*]] = arith.muli %[[VAL_94]], %[[VAL_95]] : i32
 // CHECK:                       %[[VAL_97:.*]] = arith.addi %[[VAL_84]], %[[VAL_96]] : i32
-// CHECK:                       scf.yield %[[VAL_97]], %[[VAL_85]] : i32, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK:                       scf.yield %[[VAL_97]], %[[VAL_TRUE]], %[[VAL_85]] : i32, i1, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
 // CHECK:                     } else {
-// CHECK:                       scf.yield %[[VAL_84]], %[[VAL_85]] : i32, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK:                       scf.yield %[[VAL_84]], %[[VAL_201]], %[[VAL_85]] : i32, i1, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
 // CHECK:                     }
 // CHECK:                     %[[VAL_98:.*]] = arith.cmpi eq, %[[VAL_86]], %[[VAL_89]] : index
 // CHECK:                     %[[VAL_99:.*]] = arith.addi %[[VAL_82]], %[[VAL_3]] : index
@@ -245,10 +247,15 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x
 // CHECK:                     %[[VAL_101:.*]] = arith.cmpi eq, %[[VAL_87]], %[[VAL_89]] : index
 // CHECK:                     %[[VAL_102:.*]] = arith.addi %[[VAL_83]], %[[VAL_3]] : index
 // CHECK:                     %[[VAL_103:.*]] = arith.select %[[VAL_101]], %[[VAL_102]], %[[VAL_83]] : index
-// CHECK:                     scf.yield %[[VAL_100]], %[[VAL_103]], %[[VAL_104:.*]]#0, %[[VAL_104]]#1 : index, index, i32, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK:                     scf.yield %[[VAL_100]], %[[VAL_103]], %[[VAL_104:.*]]#0, %[[VAL_104]]#1, %[[VAL_104]]#2 : index, index, i32, i1, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
 // CHECK:                   }
-// CHECK:                   %[[VAL_105:.*]] = sparse_tensor.insert %[[VAL_106:.*]]#2 into %[[VAL_106]]#3{{\[}}%[[VAL_39]], %[[VAL_63]]] : tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
-// CHECK:                   scf.yield %[[VAL_105]] : tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK:                   %[[VAL_202:.*]] = scf.if %[[VAL_74]]#3 -> (tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) {
+// CHECK:                     %[[VAL_105:.*]] = sparse_tensor.insert %[[VAL_74]]#2 into %[[VAL_74]]#4{{\[}}%[[VAL_39]], %[[VAL_63]]] : tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK:                     scf.yield %[[VAL_105]] : tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK:                   } else {
+// CHECK:                     scf.yield %[[VAL_74]]#4 : tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK:                   }
+// CHECK:                   scf.yield %[[VAL_202]] : tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
 // CHECK:                 } else {
 // CHECK:                   scf.yield %[[VAL_59]] : tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
 // CHECK:                 }

diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir
index f743954bccb17..697233b0fc09a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir
@@ -172,9 +172,8 @@ module {
     // CHECK-NEXT: ( ( 6, 0, 0, 0, 0 ), ( 0, 0, 0, 5, 0 ), ( 4, 0, 0, 3, 0 ), ( 0, 2, 0, 0, 0 ), ( 0, 11, 0, 0, 0 ) )
     // CHECK-NEXT: ( 7, 7, 9, 8, 7, 7, 12, 11, 11, 0, 0, 0, 0, 0, 0, 0 )
     // CHECK-NEXT: ( ( 7, 0, 0, 7, 0 ), ( 9, 0, 0, 0, 0 ), ( 8, 7, 0, 7, 0 ), ( 12, 11, 0, 11, 0 ), ( 0, 0, 0, 0, 0 ) )
-    // TODO: Update once identity values are no longer inserted for non-overlapping dot product
-    // CHECK-NEXT: ( 7, inf, inf, 7, 9, inf, inf, inf, 8, 7, inf, 7, 12, 11, inf, 11 )
-    // CHECK-NEXT: ( ( 7, inf, inf, 7, 0 ), ( 9, inf, inf, inf, 0 ), ( 8, 7, inf, 7, 0 ), ( 12, 11, inf, 11, 0 ), ( 0, 0, 0, 0, 0 ) )
+    // CHECK-NEXT: ( 7, 7, 9, 8, 7, 7, 12, 11, 11, 0, 0, 0, 0, 0, 0, 0 )
+    // CHECK-NEXT: ( ( 7, 0, 0, 7, 0 ), ( 9, 0, 0, 0, 0 ), ( 8, 7, 0, 7, 0 ), ( 12, 11, 0, 11, 0 ), ( 0, 0, 0, 0, 0 ) )
     //
     call @dump_mat(%sm1) : (tensor<?x?xf64, #CSR>) -> ()
     call @dump_mat(%sm2r) : (tensor<?x?xf64, #CSR>) -> ()


        

