[Mlir-commits] [mlir] [mlir][sparse] refine sparse fusion with empty tensors materialization (PR #66563)
Aart Bik
llvmlistbot at llvm.org
Fri Sep 15 17:51:51 PDT 2023
https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/66563
This is a minor step towards deprecating bufferization.alloc_tensor(). The examples now use tensor.empty() instead, and the underlying rewriting logic has been adjusted to prepare for this upcoming change.
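For readers skimming the patch, the gist in IR terms is roughly the following (a minimal sketch; the 8x8 shape and the #SM sparse encoding are placeholders taken from the tests below, not a complete example):

  // Before: empty sparse output materialized with the to-be-deprecated op.
  %3 = bufferization.alloc_tensor() : tensor<8x8xf64, #SM>

  // After: the same uninitialized materialization via tensor.empty().
  %3 = tensor.empty() : tensor<8x8xf64, #SM>

The rewriting patterns now accept tensor.empty(), in addition to bufferization.alloc_tensor(), as an uninitialized output materialization when matching fusion candidates.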
From afd923169445f8800365859145c8abd0823c5ef7 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Fri, 15 Sep 2023 17:22:34 -0700
Subject: [PATCH] [mlir][sparse] refine sparse fusion with empty tensors
materialization
This is a minor step towards deprecating bufferization.alloc_tensor().
The examples now use tensor.empty() instead, and the underlying
rewriting logic has been adjusted to prepare for this upcoming change.
---
.../Transforms/SparseTensorRewriting.cpp | 28 +++++-----
.../Dialect/SparseTensor/sparse_sddmm.mlir | 54 +++++++++----------
2 files changed, 41 insertions(+), 41 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 38e6621d54b331d..08482de5879ded7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -50,8 +50,8 @@ static bool isSparseTensor(Value v) {
}
static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
-// Helper method to find zero/uninitialized allocation.
-static bool isAlloc(OpOperand *op, bool isZero) {
+// Helper method to find zero/uninitialized tensor materialization.
+static bool isMaterializing(OpOperand *op, bool isZero) {
Value val = op->get();
// Check allocation, with zero alloc when required.
if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
@@ -60,6 +60,9 @@ static bool isAlloc(OpOperand *op, bool isZero) {
return copy && isZeroValue(copy);
return !copy;
}
+ // Check for empty tensor materialization.
+ if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
+ return !isZero;
// Last resort for zero alloc: the whole value is zero.
return isZero && isZeroValue(val);
}
@@ -219,24 +222,22 @@ struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
LogicalResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const override {
if (!op.hasTensorSemantics() || op.getNumResults() != 1 ||
- !isAlloc(op.getDpsInitOperand(0), /*isZero=*/false) ||
+ !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
!isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
return failure();
auto outputType = getRankedTensorType(op.getResult(0));
- // Yielding zero on newly allocated (all-zero) sparse tensors can be
- // optimized out directly (regardless of dynamic or static size).
+ // Yielding zero on a newly materialized sparse tensor can be
+ // optimized out directly (regardless of dynamic or static size).
if (getSparseTensorEncoding(outputType)) {
rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
return success();
}
- // Incorporate zero value into allocation copy.
+ // Use a static zero value directly instead of the materialization.
if (!outputType.hasStaticShape())
return failure();
- Value zero = constantZero(rewriter, op.getLoc(), op.getResult(0).getType());
- AllocTensorOp a =
- op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
- rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(zero); });
- rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
+ Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
+ rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
+ rewriter.eraseOp(def);
return success();
}
};
@@ -286,8 +287,8 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
!prod.getResult(0).hasOneUse())
return failure();
// Sampling consumer and sum of multiplication chain producer.
- if (!isAlloc(op.getDpsInitOperand(0), /*isZero=*/false) ||
- !isAlloc(prod.getDpsInitOperand(0), /*isZero=*/true) ||
+ if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
+ !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) ||
!isSampling(op) || !isSumOfMul(prod))
return failure();
// Modify operand structure of producer and consumer.
@@ -327,6 +328,7 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
last = rewriter.clone(*acc, mapper)->getResult(0);
rewriter.create<linalg::YieldOp>(loc, last);
// Force initial value on merged allocation for dense outputs.
+ // TODO: deal with non-alloc-tensor materialization here one day
if (!getSparseTensorEncoding(op.getResult(0).getType())) {
Value init = prod.getDpsInitOperand(0)
->get()
diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
index 610ff30a48c4a4f..707648e42cbd849 100755
--- a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
@@ -21,13 +21,12 @@
}
// CHECK-LABEL: func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
-// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf64>
-// CHECK: %[[VAL_1:.*]] = bufferization.alloc_tensor() copy(%[[VAL_0]]) {bufferization.escape = [false]} : tensor<1024x1024xf64>
-// CHECK: return %[[VAL_1]] : tensor<1024x1024xf64>
+// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf64>
+// CHECK: return %[[C0]] : tensor<1024x1024xf64>
// CHECK: }
func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
%cst = arith.constant 0.000000e+00 : f64
- %0 = bufferization.alloc_tensor() : tensor<1024x1024xf64>
+ %0 = tensor.empty() : tensor<1024x1024xf64>
%1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
@@ -40,13 +39,12 @@ func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
}
// CHECK-LABEL: func.func @fold_yield_direct_zero() -> tensor<32xf64> {
-// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : tensor<32xf64>
-// CHECK: %[[VAL_1:.*]] = bufferization.alloc_tensor() copy(%[[VAL_0]]) {bufferization.escape = [false]} : tensor<32xf64>
-// CHECK: return %[[VAL_1]] : tensor<32xf64>
+// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32xf64>
+// CHECK: return %[[C0]] : tensor<32xf64>
// CHECK: }
func.func @fold_yield_direct_zero() -> tensor<32xf64> {
%cst = arith.constant 0.000000e+00 : f64
- %0 = bufferization.alloc_tensor() : tensor<32xf64>
+ %0 = tensor.empty() : tensor<32xf64>
%1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
outs(%0 : tensor<32xf64>) {
@@ -92,9 +90,9 @@ func.func @fold_yield_direct_zero() -> tensor<32xf64> {
// CHECK: %[[VAL_32:.*]] = arith.mulf %[[VAL_30]], %[[VAL_31]] : f64
// CHECK: %[[VAL_33:.*]] = arith.addf %[[VAL_28]], %[[VAL_32]] : f64
// CHECK: memref.store %[[VAL_33]], %[[VAL_16]]{{\[}}%[[VAL_20]], %[[VAL_27]]] : memref<8x8xf64>
-// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: }
+// CHECK: }
+// CHECK: }
// CHECK: %[[VAL_34:.*]] = bufferization.to_tensor %[[VAL_16]] : memref<8x8xf64>
// CHECK: return %[[VAL_34]] : tensor<8x8xf64>
// CHECK: }
@@ -123,9 +121,9 @@ func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
}
// CHECK-LABEL: func.func @sparse_sampled_dd_unfused(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>,
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x8xf64>,
-// CHECK-SAME: %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> {
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> {
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
@@ -133,19 +131,19 @@ func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant true
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
// CHECK-DAG: %[[VAL_9:.*]] = bufferization.alloc_tensor() copy(%[[VAL_8]]) {bufferization.escape = [false]} : tensor<8x8xf64>
-// CHECK-DAG: %[[VAL_10:.*]] = bufferization.alloc_tensor() {bufferization.escape = [false]} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
+// CHECK-DAG: %[[VAL_10:.*]] = tensor.empty() : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK-DAG: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<8x8xf64>
// CHECK-DAG: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<8x8xf64>
-// CHECK-DAG: %[[VAL_13:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_14:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_15:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_16:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_17:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xf64>
+// CHECK-DAG: %[[VAL_13:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_14:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_15:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_16:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_17:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK: %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_19]] step %[[VAL_5]] iter_args(%[[VAL_22:.*]] = %[[VAL_10]]) -> (tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>) {
+// CHECK: %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_19]] step %[[VAL_5]] iter_args(%[[VAL_22:.*]] = %[[VAL_10]]) -> (tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>) {
// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_21]]] : memref<?xindex>
-// CHECK: %[[VAL_24:.*]], %[[VAL_25:.*]], %[[VAL_26:.*]], %[[VAL_27:.*]] = sparse_tensor.expand %[[VAL_10]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xf64>, memref<?xi1>, memref<?xindex>
+// CHECK: %[[VAL_24:.*]], %[[VAL_25:.*]], %[[VAL_26:.*]], %[[VAL_27:.*]] = sparse_tensor.expand %[[VAL_10]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>, memref<?xi1>, memref<?xindex>
// CHECK: %[[VAL_28:.*]] = scf.for %[[VAL_29:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] iter_args(%[[VAL_30:.*]] = %[[VAL_27]]) -> (index) {
// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_23]], %[[VAL_29]]] : memref<8x8xf64>
// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_21]]] : memref<?xindex>
@@ -170,15 +168,15 @@ func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
// CHECK: scf.yield %[[VAL_37]] : index
// CHECK: }
// CHECK: memref.store %[[VAL_44]], %[[VAL_24]]{{\[}}%[[VAL_38]]] : memref<?xf64>
-// CHECK: scf.yield %[[VAL_49:.*]] : index
+// CHECK: scf.yield %[[VAL_47]] : index
// CHECK: }
-// CHECK: scf.yield %[[VAL_50:.*]] : index
+// CHECK: scf.yield %[[VAL_35]] : index
// CHECK: }
-// CHECK: %[[VAL_51:.*]] = sparse_tensor.compress %[[VAL_24]], %[[VAL_25]], %[[VAL_26]], %[[VAL_52:.*]] into %[[VAL_22]]{{\[}}%[[VAL_23]]] : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
-// CHECK: scf.yield %[[VAL_51]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
+// CHECK: %[[VAL_49:.*]] = sparse_tensor.compress %[[VAL_24]], %[[VAL_25]], %[[VAL_26]], %[[VAL_28]] into %[[VAL_22]]{{\[}}%[[VAL_23]]] : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: scf.yield %[[VAL_49]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK: }
-// CHECK: %[[VAL_53:.*]] = sparse_tensor.load %[[VAL_54:.*]] hasInserts : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
-// CHECK: return %[[VAL_53]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
+// CHECK: %[[VAL_50:.*]] = sparse_tensor.load %[[VAL_20]] hasInserts : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: return %[[VAL_50]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK: }
func.func @sparse_sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
%arga: tensor<8x8xf64>,
@@ -194,7 +192,7 @@ func.func @sparse_sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
linalg.yield %q : f64
} -> tensor<8x8xf64>
// Sample the result with element-wise multiplication with the sparse matrix.
- %3 = bufferization.alloc_tensor() : tensor<8x8xf64, #SM>
+ %3 = tensor.empty() : tensor<8x8xf64, #SM>
%4 = linalg.generic #trait_scale
ins(%2, %args : tensor<8x8xf64>, tensor<8x8xf64, #SM>)
outs(%3 : tensor<8x8xf64, #SM>) {