[Mlir-commits] [mlir] [mlir][sparse] fold sparse convert into producer linalg op. (PR #89999)

Peiming Liu llvmlistbot at llvm.org
Fri Apr 26 09:35:41 PDT 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/89999

>From a905cb262686a347549d3eb811359046812232e6 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 24 Apr 2024 22:10:27 +0000
Subject: [PATCH 1/2] [mlir][sparse] fold sparse convert into producer generic
 operation.

---
 .../Dialect/SparseTensor/IR/SparseTensor.h    | 15 +++---
 .../Transforms/SparseTensorRewriting.cpp      | 38 +++++++++++++--
 .../Transforms/Sparsification.cpp             | 44 +++++++++++------
 .../fuse_sparse_convert_into_producer.mlir    | 48 +++++++++++++++++++
 .../SparseTensor/no_fold_into_consumer.mlir   |  2 -
 5 files changed, 121 insertions(+), 26 deletions(-)
 create mode 100644 mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 5e523ec428aefb..550e28813b4e9b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -90,17 +90,20 @@ inline MemRefType getMemRefType(T &&t) {
 SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
 
 /// Returns true iff MLIR operand has any sparse operand.
-inline bool hasAnySparseOperand(Operation *op) {
-  return llvm::any_of(op->getOperands().getTypes(), [](Type t) {
-    return getSparseTensorEncoding(t) != nullptr;
+inline bool hasAnySparseType(TypeRange types) {
+  return llvm::any_of(types, [](Type type) {
+    return getSparseTensorEncoding(type) != nullptr;
   });
 }
 
+/// Returns true iff the operation has any sparse tensor operand.
+inline bool hasAnySparseOperand(Operation *op) {
+  return hasAnySparseType(op->getOperands().getTypes());
+}
+
 /// Returns true iff MLIR operand has any sparse result.
 inline bool hasAnySparseResult(Operation *op) {
-  return llvm::any_of(op->getResults().getTypes(), [](Type t) {
-    return getSparseTensorEncoding(t) != nullptr;
-  });
+  return hasAnySparseType(op->getResults().getTypes());
 }
 
 /// Returns true iff MLIR operand has any sparse operand or result.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 5a39dfc6207707..641dcc61d7d09c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -289,6 +289,37 @@ struct FuseExtractSliceWithConcat
   }
 };
 
+/// Rewriting rule that converts direct yield of zero with initial allocation.
+struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConvertOp op,
+                                PatternRewriter &rewriter) const override {
+    auto producer = op.getSource().getDefiningOp<GenericOp>();
+    if (!producer || producer.getDpsInits().size() != 1 ||
+        !isMaterializing(producer.getDpsInitOperand(0), false) ||
+        !producer.getResult(0).hasOneUse()) {
+      return failure();
+    }
+    // Clone the materializing init (it may have other users) with the
+    // sparse result type instead of retyping the shared op in place.
+    rewriter.setInsertionPoint(producer);
+    Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp();
+    Operation *cloned = rewriter.clone(*init);
+    Type sparseTp = op.getResult().getType();
+    cloned->getResult(0).setType(sparseTp);
+    rewriter.modifyOpInPlace(producer, [&]() {
+      producer.getDpsInitsMutable().assign(cloned->getResults());
+      producer.getResult(0).setType(sparseTp);
+    });
+    // Erase the convert through the rewriter so the driver is notified.
+    rewriter.replaceOp(op, producer->getResults());
+
+    return success();
+  }
+};
+
 /// Rewriting rule that converts direct yield of zero with initial allocation.
 struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
 public:
@@ -1506,9 +1537,10 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
 //===---------------------------------------------------------------------===//
 
 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
-  patterns.add<FuseExtractSliceWithConcat, FoldInvariantYield,
-               FuseSparseMultiplyOverAdd, FuseTensorCast, GenSemiRingReduction,
-               GenSemiRingSelect, PrintRewriter>(patterns.getContext());
+  patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
+               FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
+               GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
+      patterns.getContext());
 }
 
 void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index cd046b670d9a8e..0a9bb40b458d68 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -403,6 +403,22 @@ static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
   return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
 }
 
+static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
+                                  Value sparseOut, ValueRange ivs, Value v) {
+  scf::IfOp condInsert =
+      builder.create<scf::IfOp>(loc, sparseOut.getType(), cond, /*else=*/true);
+  // Then-branch: insert `v` at `ivs` and yield the updated tensor.
+  builder.setInsertionPointToStart(condInsert.thenBlock());
+  Value res = builder.create<tensor::InsertOp>(loc, v, sparseOut, ivs);
+  builder.create<scf::YieldOp>(loc, res);
+  // Else-branch: yield `sparseOut` unchanged (no insertion).
+  builder.setInsertionPointToStart(condInsert.elseBlock());
+  builder.create<scf::YieldOp>(loc, sparseOut);
+  // Leave the builder after the scf.if and return its result tensor.
+  builder.setInsertionPointAfter(condInsert);
+  return condInsert.getResult(0);
+}
+
 /// Generates insertion code to implement dynamic tensor store.
 static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                               Value rhs) {
@@ -423,23 +439,21 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
       //     return updated chain
       //   else
       //     return unmodified chain
-      scf::IfOp ifValidLexInsert = builder.create<scf::IfOp>(
-          loc, chain.getType(), env.getValidLexInsert(),
-          /*else=*/true);
-      // True branch.
-      builder.setInsertionPointToStart(ifValidLexInsert.thenBlock());
-      Value res = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
-      builder.create<scf::YieldOp>(loc, res);
-      // False branch.
-      builder.setInsertionPointToStart(ifValidLexInsert.elseBlock());
-      builder.create<scf::YieldOp>(loc, chain);
-      // Value assignment.
-      builder.setInsertionPointAfter(ifValidLexInsert);
-      env.updateInsertionChain(ifValidLexInsert.getResult(0));
+      Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(),
+                                       chain, ivs, rhs);
+      env.updateInsertionChain(out);
     } else {
+      Value sparseOut;
+      if (!hasAnySparseType(env.op().getInputs().getTypes())) {
+        // This is an all-dense -> sparse kernel, test rhs != 0 before
+        // insertion.
+        Value nz = genIsNonzero(builder, loc, rhs);
+        sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs);
+      } else {
+        sparseOut = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
+      }
       // Generates regular insertion chain.
-      env.updateInsertionChain(
-          builder.create<tensor::InsertOp>(loc, rhs, chain, ivs));
+      env.updateInsertionChain(sparseOut);
     }
     return;
   }
diff --git a/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir b/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
new file mode 100644
index 00000000000000..077dde230fd156
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification | FileCheck %s
+
+#trait = {
+  indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+  ],
+  iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+}
+
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense) }>
+
+// CHECK-LABEL:   func.func @test(
+// CHECK:           scf.for
+// CHECK:             scf.for
+// CHECK:               scf.for
+// CHECK:                 scf.if
+// CHECK-NEXT:               tensor.insert
+// CHECK-NEXT:               scf.yield
+// CHECK-NEXT:             else
+// CHECK-NEXT:               scf.yield
+// CHECK:                 scf.yield
+// CHECK:               scf.yield
+// CHECK:             scf.yield
+// CHECK:           sparse_tensor.load
+func.func @test(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>, %arg2: tensor<128x32x32x1xf32>) -> tensor<128x32x32x1xf32, #sparse> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %cst_0 = arith.constant 1.000000e+00 : f32
+  %cst_1 = arith.constant 1.000000e+00 : f32
+  %0 = tensor.empty() : tensor<128x32x32x1xf32>
+  %1 = linalg.generic #trait
+  ins(%arg0, %arg1, %arg2 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
+  outs(%0 : tensor<128x32x32x1xf32>) {
+    ^bb0(%in: f32, %in_2: f32, %in_3: f32, %out: f32):
+      %3 = arith.subf %cst_0, %in_2 : f32
+      %4 = arith.mulf %in, %3 : f32
+      %5 = arith.mulf %4, %cst_1 : f32
+      %6 = arith.addf %5, %in_3 : f32
+      %7 = arith.subf %6, %cst_0 : f32
+      %8 = arith.cmpf uge, %7, %cst : f32
+      %9 = arith.uitofp %8 : i1 to f32
+      linalg.yield %9 : f32
+    } -> tensor<128x32x32x1xf32>
+  %2 = sparse_tensor.convert %1 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #sparse>
+  return %2 : tensor<128x32x32x1xf32, #sparse>
+}
diff --git a/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir b/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
index bbc7f397e793fe..f2f64567d5bd01 100644
--- a/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
+++ b/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
@@ -24,7 +24,6 @@ module {
   // CHECK:       arith.constant
   // CHECK:       tensor.empty()
   // CHECK:       linalg.generic
-  // CHECK:       sparse_tensor.convert
   // CHECK:       return
   //
   func.func @avoid_fold(%0: tensor<10x20x30xf64, #sparse>) -> tensor<10x20x30xf64, #sparse> {
@@ -44,4 +43,3 @@ module {
     return %cast : tensor<10x20x30xf64, #sparse>
   }
 }
-

>From c689fbbce4b6033f1ee6854fdfb3dd411303b16e Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 24 Apr 2024 23:25:10 +0000
Subject: [PATCH 2/2] address comments

---
 .../Dialect/SparseTensor/IR/SparseTensor.h    |  2 +-
 .../Transforms/SparseTensorRewriting.cpp      |  2 +-
 .../fuse_sparse_convert_into_producer.mlir    | 40 ++++++++++++++---
 .../SparseTensor/no_fold_into_consumer.mlir   | 45 -------------------
 4 files changed, 37 insertions(+), 52 deletions(-)
 delete mode 100644 mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 550e28813b4e9b..b182b4c72b9535 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -89,7 +89,7 @@ inline MemRefType getMemRefType(T &&t) {
 /// Returns null-attribute for any type without an encoding.
 SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
 
-/// Returns true iff MLIR operand has any sparse operand.
+/// Returns true iff the type range has any sparse tensor type.
 inline bool hasAnySparseType(TypeRange types) {
   return llvm::any_of(types, [](Type type) {
     return getSparseTensorEncoding(type) != nullptr;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 641dcc61d7d09c..9a8c6422a7ff62 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -289,7 +289,7 @@ struct FuseExtractSliceWithConcat
   }
 };
 
-/// Rewriting rule that converts direct yield of zero with initial allocation.
+/// Rewriting rule that fuses sparse_tensor.convert into producer.
 struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
diff --git a/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir b/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
index 077dde230fd156..efa92e565ba575 100644
--- a/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
+++ b/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
@@ -1,3 +1,4 @@
+// RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map  | FileCheck %s --check-prefix=CHECK-FOLD
 // RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification | FileCheck %s
 
 #trait = {
@@ -10,9 +11,12 @@
   iterator_types = ["parallel", "parallel", "parallel", "parallel"]
 }
 
-#sparse = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense) }>
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 
-// CHECK-LABEL:   func.func @test(
+#COO = #sparse_tensor.encoding<{map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa))}>
+#CCCD = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense) }>
+
+// CHECK-LABEL:   func.func @fold_convert(
 // CHECK:           scf.for
 // CHECK:             scf.for
 // CHECK:               scf.for
@@ -25,7 +29,10 @@
 // CHECK:               scf.yield
 // CHECK:             scf.yield
 // CHECK:           sparse_tensor.load
-func.func @test(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>, %arg2: tensor<128x32x32x1xf32>) -> tensor<128x32x32x1xf32, #sparse> {
+
+// CHECK-FOLD-LABEL:   func.func @fold_convert(
+// CHECK-FOLD-NOT:     sparse_tensor.convert
+func.func @fold_convert(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>, %arg2: tensor<128x32x32x1xf32>) -> tensor<128x32x32x1xf32, #CCCD> {
   %cst = arith.constant 0.000000e+00 : f32
   %cst_0 = arith.constant 1.000000e+00 : f32
   %cst_1 = arith.constant 1.000000e+00 : f32
@@ -43,6 +50,29 @@ func.func @test(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>,
       %9 = arith.uitofp %8 : i1 to f32
       linalg.yield %9 : f32
     } -> tensor<128x32x32x1xf32>
-  %2 = sparse_tensor.convert %1 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #sparse>
-  return %2 : tensor<128x32x32x1xf32, #sparse>
+  %2 = sparse_tensor.convert %1 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #CCCD>
+  return %2 : tensor<128x32x32x1xf32, #CCCD>
+}
+
+
+// FIXME: The following kernel is not sparsifiable because `arith.select`
+// operations are not handled by the sparse compiler at the moment.
+//
+// CHECK-FOLD-LABEL:   func.func @fold_cast(
+// CHECK-FOLD-NOT:     sparse_tensor.convert
+func.func @fold_cast(%0: tensor<10x20x30xf64, #COO>) -> tensor<10x20x30xf64, #COO> {
+  %cst = arith.constant 0.000000e+00 : f64
+  %1 = tensor.empty() : tensor<10x20x30xf64>
+  %2 = linalg.generic { indexing_maps = [#map, #map],
+                        iterator_types = ["parallel", "parallel", "parallel"]
+                      }
+  ins (%0 : tensor<10x20x30xf64, #COO>)
+  outs(%1 : tensor<10x20x30xf64>) {
+      ^bb0(%in: f64, %out: f64):
+        %4 = arith.cmpf ugt, %in, %cst : f64
+        %5 = arith.select %4, %in, %cst : f64
+        linalg.yield %5 : f64
+  } -> tensor<10x20x30xf64>
+  %cast = tensor.cast %2 : tensor<10x20x30xf64> to tensor<10x20x30xf64, #COO>
+  return %cast : tensor<10x20x30xf64, #COO>
 }
diff --git a/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir b/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
deleted file mode 100644
index f2f64567d5bd01..00000000000000
--- a/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
+++ /dev/null
@@ -1,45 +0,0 @@
-// RUN: mlir-opt %s --canonicalize --pre-sparsification-rewrite | FileCheck %s
-
-#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-
-#sparse = #sparse_tensor.encoding<{
-  map = (d0, d1, d2) ->
-          (d0 : compressed(nonunique),
-	   d1 : singleton(nonunique, soa),
-	   d2 : singleton(soa)),
-  posWidth = 64,
-  crdWidth = 64
-}>
-
-
-module {
-  //
-  // This IR should not end up in an infinite loop trying to fold
-  // the linalg producer into the tensor cast consumer (even though
-  // static sizes can fold, the different encodings cannot). The
-  // cast was sloppy to begin with (but it has been observed by
-  // external sources) and can be easily repaired by the sparsifier.
-  //
-  // CHECK-LABEL: func @avoid_fold
-  // CHECK:       arith.constant
-  // CHECK:       tensor.empty()
-  // CHECK:       linalg.generic
-  // CHECK:       return
-  //
-  func.func @avoid_fold(%0: tensor<10x20x30xf64, #sparse>) -> tensor<10x20x30xf64, #sparse> {
-    %1 = tensor.empty() : tensor<10x20x30xf64>
-    %2 = linalg.generic { indexing_maps = [#map, #map],
-                          iterator_types = ["parallel", "parallel", "parallel"]
-                        }
-    ins (%0 : tensor<10x20x30xf64, #sparse>)
-    outs(%1 : tensor<10x20x30xf64>) {
-        ^bb0(%in: f64, %out: f64):
-          %cst = arith.constant 0.000000e+00 : f64
-          %4 = arith.cmpf ugt, %in, %cst : f64
-          %5 = arith.select %4, %in, %cst : f64
-          linalg.yield %5 : f64
-    } -> tensor<10x20x30xf64>
-    %cast = tensor.cast %2 : tensor<10x20x30xf64> to tensor<10x20x30xf64, #sparse>
-    return %cast : tensor<10x20x30xf64, #sparse>
-  }
-}



More information about the Mlir-commits mailing list