[Mlir-commits] [mlir] [mlir][sparse] fold sparse convert into producer generic op. (PR #89999)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 24 15:12:44 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/89999.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h (+9-6) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+35-3) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+29-15) 
- (added) mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir (+48) 
- (modified) mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir (-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 5e523ec428aefb..550e28813b4e9b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -90,17 +90,20 @@ inline MemRefType getMemRefType(T &&t) {
 SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
 
-/// Returns true iff MLIR operand has any sparse operand.
-inline bool hasAnySparseOperand(Operation *op) {
-  return llvm::any_of(op->getOperands().getTypes(), [](Type t) {
-    return getSparseTensorEncoding(t) != nullptr;
+/// Returns true iff the given type range contains any sparse tensor type.
+inline bool hasAnySparseType(TypeRange types) {
+  return llvm::any_of(types, [](Type type) {
+    return getSparseTensorEncoding(type) != nullptr;
   });
 }
 
+/// Returns true iff the operation has any sparse operand.
+inline bool hasAnySparseOperand(Operation *op) {
+  return hasAnySparseType(op->getOperands().getTypes());
+}
+
 /// Returns true iff MLIR operand has any sparse result.
 inline bool hasAnySparseResult(Operation *op) {
-  return llvm::any_of(op->getResults().getTypes(), [](Type t) {
-    return getSparseTensorEncoding(t) != nullptr;
-  });
+  return hasAnySparseType(op->getResults().getTypes());
 }
 
 /// Returns true iff MLIR operand has any sparse operand or result.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 5a39dfc6207707..641dcc61d7d09c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -289,6 +289,37 @@ struct FuseExtractSliceWithConcat
   }
 };
 
+/// Rewriting rule that folds a sparse_tensor.convert into its producer
+/// generic op, by retyping the generic op (and its materializing init) to
+/// directly produce the converted sparse tensor.
+struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConvertOp op,
+                                PatternRewriter &rewriter) const override {
+    auto producer = op.getSource().getDefiningOp<GenericOp>();
+    if (!producer || producer.getDpsInits().size() != 1 ||
+        !isMaterializing(producer.getDpsInitOperand(0), false) ||
+        !producer.getResult(0).hasOneUse()) {
+      return failure();
+    }
+    rewriter.modifyOpInPlace(producer, [&]() {
+      producer.getResult(0).setType(op.getResult().getType());
+    });
+
+    Operation *materializeOp =
+        producer.getDpsInitOperand(0)->get().getDefiningOp();
+
+    rewriter.modifyOpInPlace(materializeOp, [&]() {
+      materializeOp->getResult(0).setType(op.getResult().getType());
+    });
+
+    rewriter.replaceAllOpUsesWith(op, producer);
+    rewriter.eraseOp(op);
+
+    return success();
+  }
+};
+
 /// Rewriting rule that converts direct yield of zero with initial allocation.
 struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
 public:
@@ -1506,9 +1537,10 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
 //===---------------------------------------------------------------------===//
 
 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
-  patterns.add<FuseExtractSliceWithConcat, FoldInvariantYield,
-               FuseSparseMultiplyOverAdd, FuseTensorCast, GenSemiRingReduction,
-               GenSemiRingSelect, PrintRewriter>(patterns.getContext());
+  patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
+               FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
+               GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
+      patterns.getContext());
 }
 
 void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index cd046b670d9a8e..0a9bb40b458d68 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -403,6 +403,22 @@ static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
   return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
 }
 
+static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
+                                  Value sparseOut, ValueRange ivs, Value v) {
+  scf::IfOp condInsert =
+      builder.create<scf::IfOp>(loc, sparseOut.getType(), cond, true);
+  // True branch.
+  builder.setInsertionPointToStart(condInsert.thenBlock());
+  Value res = builder.create<tensor::InsertOp>(loc, v, sparseOut, ivs);
+  builder.create<scf::YieldOp>(loc, res);
+  // False branch.
+  builder.setInsertionPointToStart(condInsert.elseBlock());
+  builder.create<scf::YieldOp>(loc, sparseOut);
+  // Value assignment.
+  builder.setInsertionPointAfter(condInsert);
+  return condInsert.getResult(0);
+}
+
 /// Generates insertion code to implement dynamic tensor store.
 static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                               Value rhs) {
@@ -423,23 +439,21 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
       //     return updated chain
       //   else
       //     return unmodified chain
-      scf::IfOp ifValidLexInsert = builder.create<scf::IfOp>(
-          loc, chain.getType(), env.getValidLexInsert(),
-          /*else=*/true);
-      // True branch.
-      builder.setInsertionPointToStart(ifValidLexInsert.thenBlock());
-      Value res = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
-      builder.create<scf::YieldOp>(loc, res);
-      // False branch.
-      builder.setInsertionPointToStart(ifValidLexInsert.elseBlock());
-      builder.create<scf::YieldOp>(loc, chain);
-      // Value assignment.
-      builder.setInsertionPointAfter(ifValidLexInsert);
-      env.updateInsertionChain(ifValidLexInsert.getResult(0));
+      Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(),
+                                       chain, ivs, rhs);
+      env.updateInsertionChain(out);
     } else {
+      Value sparseOut;
+      if (!hasAnySparseType(env.op().getInputs().getTypes())) {
+        // This is an all-dense -> sparse kernel; test rhs != 0 before
+        // insertion.
+        Value nz = genIsNonzero(builder, loc, rhs);
+        sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs);
+      } else {
+        sparseOut = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
+      }
       // Generates regular insertion chain.
-      env.updateInsertionChain(
-          builder.create<tensor::InsertOp>(loc, rhs, chain, ivs));
+      env.updateInsertionChain(sparseOut);
     }
     return;
   }
diff --git a/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir b/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
new file mode 100644
index 00000000000000..077dde230fd156
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification | FileCheck %s
+
+#trait = {
+  indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+  ],
+  iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+}
+
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense) }>
+
+// CHECK-LABEL:   func.func @test(
+// CHECK:           scf.for
+// CHECK:             scf.for
+// CHECK:               scf.for
+// CHECK:                 scf.if
+// CHECK-NEXT:               tensor.insert
+// CHECK-NEXT:               scf.yield
+// CHECK-NEXT:             else
+// CHECK-NEXT:               scf.yield
+// CHECK:                 scf.yield
+// CHECK:               scf.yield
+// CHECK:             scf.yield
+// CHECK:           sparse_tensor.load
+func.func @test(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>, %arg2: tensor<128x32x32x1xf32>) -> tensor<128x32x32x1xf32, #sparse> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %cst_0 = arith.constant 1.000000e+00 : f32
+  %cst_1 = arith.constant 1.000000e+00 : f32
+  %0 = tensor.empty() : tensor<128x32x32x1xf32>
+  %1 = linalg.generic #trait
+  ins(%arg0, %arg1, %arg2 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
+  outs(%0 : tensor<128x32x32x1xf32>) {
+    ^bb0(%in: f32, %in_2: f32, %in_3: f32, %out: f32):
+      %3 = arith.subf %cst_0, %in_2 : f32
+      %4 = arith.mulf %in, %3 : f32
+      %5 = arith.mulf %4, %cst_1 : f32
+      %6 = arith.addf %5, %in_3 : f32
+      %7 = arith.subf %6, %cst_0 : f32
+      %8 = arith.cmpf uge, %7, %cst : f32
+      %9 = arith.uitofp %8 : i1 to f32
+      linalg.yield %9 : f32
+    } -> tensor<128x32x32x1xf32>
+  %2 = sparse_tensor.convert %1 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #sparse>
+  return %2 : tensor<128x32x32x1xf32, #sparse>
+}
diff --git a/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir b/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
index bbc7f397e793fe..f2f64567d5bd01 100644
--- a/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
+++ b/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
@@ -24,7 +24,6 @@ module {
   // CHECK:       arith.constant
   // CHECK:       tensor.empty()
   // CHECK:       linalg.generic
-  // CHECK:       sparse_tensor.convert
   // CHECK:       return
   //
   func.func @avoid_fold(%0: tensor<10x20x30xf64, #sparse>) -> tensor<10x20x30xf64, #sparse> {
@@ -44,4 +43,3 @@ module {
     return %cast : tensor<10x20x30xf64, #sparse>
   }
 }
-

``````````
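For readers skimming the diff, here is a rough sketch of the effect of the new `FoldConvertIntoProducer` pattern. This example is not taken from the PR; the shapes, the `#sparse` encoding, and the function names are invented for illustration:

```mlir
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
#map = affine_map<(d0, d1) -> (d0, d1)>

// Before the rewrite: the generic computes into a dense buffer, and its sole
// use is a conversion of the result to sparse.
func.func @before(%a: tensor<4x8xf32>) -> tensor<4x8xf32, #sparse> {
  %empty = tensor.empty() : tensor<4x8xf32>
  %dense = linalg.generic {indexing_maps = [#map, #map],
                           iterator_types = ["parallel", "parallel"]}
      ins(%a : tensor<4x8xf32>) outs(%empty : tensor<4x8xf32>) {
    ^bb0(%in: f32, %out: f32):
      %sum = arith.addf %in, %in : f32
      linalg.yield %sum : f32
  } -> tensor<4x8xf32>
  %conv = sparse_tensor.convert %dense : tensor<4x8xf32> to tensor<4x8xf32, #sparse>
  return %conv : tensor<4x8xf32, #sparse>
}

// After the rewrite: the convert is folded away; the materializing
// tensor.empty and the generic result are retyped to the sparse encoding,
// so sparsification sees an all-dense -> sparse kernel.
func.func @after(%a: tensor<4x8xf32>) -> tensor<4x8xf32, #sparse> {
  %empty = tensor.empty() : tensor<4x8xf32, #sparse>
  %sp = linalg.generic {indexing_maps = [#map, #map],
                        iterator_types = ["parallel", "parallel"]}
      ins(%a : tensor<4x8xf32>) outs(%empty : tensor<4x8xf32, #sparse>) {
    ^bb0(%in: f32, %out: f32):
      %sum = arith.addf %in, %in : f32
      linalg.yield %sum : f32
  } -> tensor<4x8xf32, #sparse>
  return %sp : tensor<4x8xf32, #sparse>
}
```

For such all-dense -> sparse kernels, `genInsertionStore` now guards each store with a nonzero test via the new `genConditionalInsert` helper, so only nonzero values are inserted into the sparse output. The generated IR has roughly this shape (again a sketch with invented value names, reusing the `#sparse` encoding from above):

```mlir
func.func @cond_insert(%chain: tensor<4x8xf32, #sparse>, %v: f32,
                       %i: index, %j: index) -> tensor<4x8xf32, #sparse> {
  %zero = arith.constant 0.0 : f32
  // Only extend the insertion chain when the computed value is nonzero.
  %nz = arith.cmpf une, %v, %zero : f32
  %res = scf.if %nz -> (tensor<4x8xf32, #sparse>) {
    %ins = tensor.insert %v into %chain[%i, %j] : tensor<4x8xf32, #sparse>
    scf.yield %ins : tensor<4x8xf32, #sparse>
  } else {
    scf.yield %chain : tensor<4x8xf32, #sparse>
  }
  return %res : tensor<4x8xf32, #sparse>
}
```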

</details>


https://github.com/llvm/llvm-project/pull/89999


More information about the Mlir-commits mailing list