[clang] [mlir][sparse] refine sparse fusion with empty tensors materialization (PR #66563)

Aart Bik via cfe-commits cfe-commits at lists.llvm.org
Mon Sep 18 08:53:45 PDT 2023


https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/66563

>From afd923169445f8800365859145c8abd0823c5ef7 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Fri, 15 Sep 2023 17:22:34 -0700
Subject: [PATCH] [mlir][sparse] refine sparse fusion with empty tensors
 materialization

This is a minor step towards deprecating bufferization.alloc_tensor().
It replaces alloc_tensor() with tensor.empty() in the test examples and
adjusts the underlying rewriting logic to prepare for this upcoming change.
---
 .../Transforms/SparseTensorRewriting.cpp      | 28 +++++-----
 .../Dialect/SparseTensor/sparse_sddmm.mlir    | 54 +++++++++----------
 2 files changed, 41 insertions(+), 41 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 38e6621d54b331d..08482de5879ded7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -50,8 +50,8 @@ static bool isSparseTensor(Value v) {
 }
 static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
 
-// Helper method to find zero/uninitialized allocation.
-static bool isAlloc(OpOperand *op, bool isZero) {
+// Helper method to find zero/uninitialized tensor materialization.
+static bool isMaterializing(OpOperand *op, bool isZero) {
   Value val = op->get();
   // Check allocation, with zero alloc when required.
   if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
@@ -60,6 +60,9 @@ static bool isAlloc(OpOperand *op, bool isZero) {
       return copy && isZeroValue(copy);
     return !copy;
   }
+  // Check for empty tensor materialization.
+  if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
+    return !isZero;
   // Last resort for zero alloc: the whole value is zero.
   return isZero && isZeroValue(val);
 }
@@ -219,24 +222,22 @@ struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
   LogicalResult matchAndRewrite(GenericOp op,
                                 PatternRewriter &rewriter) const override {
     if (!op.hasTensorSemantics() || op.getNumResults() != 1 ||
-        !isAlloc(op.getDpsInitOperand(0), /*isZero=*/false) ||
+        !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
         !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
       return failure();
     auto outputType = getRankedTensorType(op.getResult(0));
-    // Yielding zero on newly allocated (all-zero) sparse tensors can be
-    // optimized out directly (regardless of dynamic or static size).
+    // Yielding zero on a newly materialized sparse tensor can be
+    // optimized out directly (regardless of dynamic or static size).
     if (getSparseTensorEncoding(outputType)) {
       rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
       return success();
     }
-    // Incorporate zero value into allocation copy.
+    // Use static zero value directly instead of materialization.
     if (!outputType.hasStaticShape())
       return failure();
-    Value zero = constantZero(rewriter, op.getLoc(), op.getResult(0).getType());
-    AllocTensorOp a =
-        op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
-    rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(zero); });
-    rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
+    Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
+    rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
+    rewriter.eraseOp(def);
     return success();
   }
 };
@@ -286,8 +287,8 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
         !prod.getResult(0).hasOneUse())
       return failure();
     // Sampling consumer and sum of multiplication chain producer.
-    if (!isAlloc(op.getDpsInitOperand(0), /*isZero=*/false) ||
-        !isAlloc(prod.getDpsInitOperand(0), /*isZero=*/true) ||
+    if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
+        !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) ||
         !isSampling(op) || !isSumOfMul(prod))
       return failure();
     // Modify operand structure of producer and consumer.
@@ -327,6 +328,7 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
     last = rewriter.clone(*acc, mapper)->getResult(0);
     rewriter.create<linalg::YieldOp>(loc, last);
     // Force initial value on merged allocation for dense outputs.
+    // TODO: handle materializations other than alloc_tensor here one day.
     if (!getSparseTensorEncoding(op.getResult(0).getType())) {
       Value init = prod.getDpsInitOperand(0)
                        ->get()
diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
index 610ff30a48c4a4f..707648e42cbd849 100755
--- a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
@@ -21,13 +21,12 @@
 }
 
 // CHECK-LABEL: func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
-// CHECK:         %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf64>
-// CHECK:         %[[VAL_1:.*]] = bufferization.alloc_tensor() copy(%[[VAL_0]]) {bufferization.escape = [false]} : tensor<1024x1024xf64>
-// CHECK:         return %[[VAL_1]] : tensor<1024x1024xf64>
+// CHECK:         %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf64>
+// CHECK:         return %[[C0]] : tensor<1024x1024xf64>
 // CHECK:       }
 func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
   %cst = arith.constant 0.000000e+00 : f64
-  %0 = bufferization.alloc_tensor() : tensor<1024x1024xf64>
+  %0 = tensor.empty() : tensor<1024x1024xf64>
   %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>,
                                         affine_map<(d0, d1) -> (d0, d1)>],
                                         iterator_types = ["parallel", "parallel"]}
@@ -40,13 +39,12 @@ func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
 }
 
 // CHECK-LABEL: func.func @fold_yield_direct_zero() -> tensor<32xf64> {
-// CHECK:         %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : tensor<32xf64>
-// CHECK:         %[[VAL_1:.*]] = bufferization.alloc_tensor() copy(%[[VAL_0]]) {bufferization.escape = [false]} : tensor<32xf64>
-// CHECK:         return %[[VAL_1]] : tensor<32xf64>
+// CHECK:         %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32xf64>
+// CHECK:         return %[[C0]] : tensor<32xf64>
 // CHECK:       }
 func.func @fold_yield_direct_zero() -> tensor<32xf64> {
   %cst = arith.constant 0.000000e+00 : f64
-  %0 = bufferization.alloc_tensor() : tensor<32xf64>
+  %0 = tensor.empty() : tensor<32xf64>
   %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>],
                                         iterator_types = ["parallel"]}
                                         outs(%0 : tensor<32xf64>) {
@@ -92,9 +90,9 @@ func.func @fold_yield_direct_zero() -> tensor<32xf64> {
 // CHECK:                 %[[VAL_32:.*]] = arith.mulf %[[VAL_30]], %[[VAL_31]] : f64
 // CHECK:                 %[[VAL_33:.*]] = arith.addf %[[VAL_28]], %[[VAL_32]] : f64
 // CHECK:                 memref.store %[[VAL_33]], %[[VAL_16]]{{\[}}%[[VAL_20]], %[[VAL_27]]] : memref<8x8xf64>
-// CHECK:               } {"Emitted from" = "linalg.generic"}
-// CHECK:             } {"Emitted from" = "linalg.generic"}
-// CHECK:           } {"Emitted from" = "linalg.generic"}
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
 // CHECK:           %[[VAL_34:.*]] = bufferization.to_tensor %[[VAL_16]] : memref<8x8xf64>
 // CHECK:           return %[[VAL_34]] : tensor<8x8xf64>
 // CHECK:         }
@@ -123,9 +121,9 @@ func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
 }
 
 // CHECK-LABEL:   func.func @sparse_sampled_dd_unfused(
-// CHECK-SAME:      %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>,
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>,
 // CHECK-SAME:      %[[VAL_1:.*]]: tensor<8x8xf64>,
-// CHECK-SAME:      %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> {
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> {
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 8 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
@@ -133,19 +131,19 @@ func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant true
 // CHECK-DAG:       %[[VAL_8:.*]] = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
 // CHECK-DAG:       %[[VAL_9:.*]] = bufferization.alloc_tensor() copy(%[[VAL_8]]) {bufferization.escape = [false]} : tensor<8x8xf64>
-// CHECK-DAG:       %[[VAL_10:.*]] = bufferization.alloc_tensor() {bufferization.escape = [false]} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
+// CHECK-DAG:       %[[VAL_10:.*]] = tensor.empty() : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<8x8xf64>
 // CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<8x8xf64>
-// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_17:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xf64>
+// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_17:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
 // CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_19:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK:           %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_19]] step %[[VAL_5]] iter_args(%[[VAL_22:.*]] = %[[VAL_10]]) -> (tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>) {
+// CHECK:           %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_19]] step %[[VAL_5]] iter_args(%[[VAL_22:.*]] = %[[VAL_10]]) -> (tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>) {
 // CHECK:             %[[VAL_23:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_21]]] : memref<?xindex>
-// CHECK:             %[[VAL_24:.*]], %[[VAL_25:.*]], %[[VAL_26:.*]], %[[VAL_27:.*]] = sparse_tensor.expand %[[VAL_10]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xf64>, memref<?xi1>, memref<?xindex>
+// CHECK:             %[[VAL_24:.*]], %[[VAL_25:.*]], %[[VAL_26:.*]], %[[VAL_27:.*]] = sparse_tensor.expand %[[VAL_10]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>, memref<?xi1>, memref<?xindex>
 // CHECK:             %[[VAL_28:.*]] = scf.for %[[VAL_29:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] iter_args(%[[VAL_30:.*]] = %[[VAL_27]]) -> (index) {
 // CHECK:               %[[VAL_31:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_23]], %[[VAL_29]]] : memref<8x8xf64>
 // CHECK:               %[[VAL_32:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_21]]] : memref<?xindex>
@@ -170,15 +168,15 @@ func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
 // CHECK:                   scf.yield %[[VAL_37]] : index
 // CHECK:                 }
 // CHECK:                 memref.store %[[VAL_44]], %[[VAL_24]]{{\[}}%[[VAL_38]]] : memref<?xf64>
-// CHECK:                 scf.yield %[[VAL_49:.*]] : index
+// CHECK:                 scf.yield %[[VAL_47]] : index
 // CHECK:               }
-// CHECK:               scf.yield %[[VAL_50:.*]] : index
+// CHECK:               scf.yield %[[VAL_35]] : index
 // CHECK:             }
-// CHECK:             %[[VAL_51:.*]] = sparse_tensor.compress %[[VAL_24]], %[[VAL_25]], %[[VAL_26]], %[[VAL_52:.*]] into %[[VAL_22]]{{\[}}%[[VAL_23]]] : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
-// CHECK:             scf.yield %[[VAL_51]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
+// CHECK:             %[[VAL_49:.*]] = sparse_tensor.compress %[[VAL_24]], %[[VAL_25]], %[[VAL_26]], %[[VAL_28]] into %[[VAL_22]]{{\[}}%[[VAL_23]]] : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:             scf.yield %[[VAL_49]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:           }
-// CHECK:           %[[VAL_53:.*]] = sparse_tensor.load %[[VAL_54:.*]] hasInserts : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
-// CHECK:           return %[[VAL_53]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
+// CHECK:           %[[VAL_50:.*]] = sparse_tensor.load %[[VAL_20]] hasInserts : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           return %[[VAL_50]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:         }
 func.func @sparse_sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
                                      %arga: tensor<8x8xf64>,
@@ -194,7 +192,7 @@ func.func @sparse_sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
         linalg.yield %q : f64
   } -> tensor<8x8xf64>
   // Sample the result with elements-wise multiplication with sparse matrix.
-  %3 = bufferization.alloc_tensor() : tensor<8x8xf64, #SM>
+  %3 = tensor.empty() : tensor<8x8xf64, #SM>
   %4 = linalg.generic #trait_scale
     ins(%2, %args : tensor<8x8xf64>, tensor<8x8xf64, #SM>)
     outs(%3 : tensor<8x8xf64, #SM>) {



More information about the cfe-commits mailing list