[Mlir-commits] [mlir] c7bb69b - [mlir][sparse] replace zero yield generic op with copy in allocation
Author: Aart Bik
Date: 2022-08-04T09:33:57-07:00
New Revision: c7bb69bc7546887b98c5cc3e5c50318f85b56eaf
URL: https://github.com/llvm/llvm-project/commit/c7bb69bc7546887b98c5cc3e5c50318f85b56eaf
DIFF: https://github.com/llvm/llvm-project/commit/c7bb69bc7546887b98c5cc3e5c50318f85b56eaf.diff
LOG: [mlir][sparse] replace zero yield generic op with copy in allocation
This handles patterns that are sometimes generated by the front-end
and would otherwise prohibit fusion of SDDMM-flavored kernels.
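As a minimal before/after sketch (illustrative 8x8 shape; the exact
FileCheck expectations are in the test diff below), the new rewrite turns
a generic op that merely yields a zero value into a zero-initialized
allocation:

  %cst = arith.constant 0.000000e+00 : f64
  %0 = bufferization.alloc_tensor() : tensor<8x8xf64>
  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
                       iterator_types = ["parallel", "parallel"]}
       outs(%0 : tensor<8x8xf64>) {
  ^bb0(%x: f64):
    linalg.yield %cst : f64
  } -> tensor<8x8xf64>

and after the rewrite the zero simply becomes the copy value of the
allocation:

  %cst_0 = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
  %1 = bufferization.alloc_tensor() copy(%cst_0) : tensor<8x8xf64>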
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D131126
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index f0aef0f0f6386..b4cd6c36ced7a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -10,6 +10,8 @@
//
//===----------------------------------------------------------------------===//
+#include "CodegenUtils.h"
+
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -94,12 +96,50 @@ static bool isSumOfMul(GenericOp op) {
return false;
}
+// Helper to detect a yield of a zero value (defined directly or passed in as a zero input).
+static bool isZeroYield(GenericOp op) {
+ auto yieldOp = cast<linalg::YieldOp>(op.region().front().getTerminator());
+ if (auto arg = yieldOp.getOperand(0).dyn_cast<BlockArgument>()) {
+ if (arg.getOwner()->getParentOp() == op) {
+ OpOperand *t = op.getInputAndOutputOperands()[arg.getArgNumber()];
+ return matchPattern(t->get(), m_Zero()) ||
+ matchPattern(t->get(), m_AnyZeroFloat());
+ }
+ } else if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
+ return matchPattern(def, m_Zero()) || matchPattern(def, m_AnyZeroFloat());
+ }
+ return false;
+}
+
//===---------------------------------------------------------------------===//
// The actual sparse tensor rewriting rules.
//===---------------------------------------------------------------------===//
namespace {
+/// Rewriting rule that folds a yield of a zero value into the copy value of the initial allocation.
+struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
+public:
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(GenericOp op,
+ PatternRewriter &rewriter) const override {
+ if (!op.hasTensorSemantics() || op.getNumResults() != 1 ||
+ !isAlloc(op.getOutputOperand(0), /*isZero=*/false) || !isZeroYield(op))
+ return failure();
+ auto outputType = op.getResult(0).getType().cast<RankedTensorType>();
+ if (!outputType.hasStaticShape() || getSparseTensorEncoding(outputType))
+ return failure();
+ // Incorporate zero value into allocation copy.
+ Value zero = constantZero(rewriter, op.getLoc(), op.getResult(0).getType());
+ AllocTensorOp a =
+ op.getOutputOperand(0)->get().getDefiningOp<AllocTensorOp>();
+ rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(zero); });
+ rewriter.replaceOp(op, op.getOutputOperand(0)->get());
+ return success();
+ }
+};
+
/// Rewriting rule that converts two kernels:
///
/// T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
@@ -187,11 +227,13 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
rewriter.create<linalg::YieldOp>(loc, last);
// Force initial value on merged allocation for dense outputs.
if (!getSparseTensorEncoding(op.getResult(0).getType())) {
- AllocTensorOp a1 =
- prod.getOutputOperand(0)->get().getDefiningOp<AllocTensorOp>();
- AllocTensorOp a2 =
+ Value init = prod.getOutputOperand(0)
+ ->get()
+ .getDefiningOp<AllocTensorOp>()
+ .getCopy();
+ AllocTensorOp a =
op.getOutputOperand(0)->get().getDefiningOp<AllocTensorOp>();
- a2.getCopyMutable().assign(a1.getCopy());
+ rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(init); });
}
// Replace consumer with fused operation. Old producer
// and consumer ops will be removed by DCE.
@@ -253,7 +295,7 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
//===---------------------------------------------------------------------===//
void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns) {
- patterns
- .add<FuseSparseMultiplyOverAdd, ReshapeRewriter<tensor::ExpandShapeOp>,
- ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
+ patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd,
+ ReshapeRewriter<tensor::ExpandShapeOp>,
+ ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
}
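Note that isZeroYield covers two shapes of "yield zero": a zero constant
yielded directly (second test case below), and a zero passed in as a
scalar input and yielded through its block argument. A self-contained
sketch of the latter (hypothetical function name and shape, mirroring the
first test case below):

  func.func @yield_zero_via_arg() -> tensor<16x16xf64> {
    %cst = arith.constant 0.000000e+00 : f64
    %0 = bufferization.alloc_tensor() : tensor<16x16xf64>
    %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>,
                                          affine_map<(d0, d1) -> (d0, d1)>],
                         iterator_types = ["parallel", "parallel"]}
         ins(%cst : f64) outs(%0 : tensor<16x16xf64>) {
    ^bb0(%a: f64, %x: f64):
      linalg.yield %a : f64
    } -> tensor<16x16xf64>
    return %1 : tensor<16x16xf64>
  }

In both cases the fold only fires when the result is a statically shaped
dense tensor (no sparse encoding) whose output operand stems from a
bufferization.alloc_tensor.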
diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
old mode 100644
new mode 100755
index 4dd4fd328ce78..9ce52a0a3520d
--- a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
@@ -20,6 +20,42 @@
iterator_types = ["parallel", "parallel"]
}
+// CHECK-LABEL: func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf64>
+// CHECK: %[[VAL_1:.*]] = bufferization.alloc_tensor() copy(%[[VAL_0]]) {bufferization.escape = [false], memory_space = 0 : ui64} : tensor<1024x1024xf64>
+// CHECK: return %[[VAL_1]] : tensor<1024x1024xf64>
+// CHECK: }
+func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
+ %cst = arith.constant 0.000000e+00 : f64
+ %0 = linalg.init_tensor [1024, 1024] : tensor<1024x1024xf64>
+ %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%cst : f64)
+ outs(%0 : tensor<1024x1024xf64>) {
+ ^bb0(%a: f64, %x: f64):
+ linalg.yield %a : f64
+ } -> tensor<1024x1024xf64>
+ return %1 : tensor<1024x1024xf64>
+}
+
+// CHECK-LABEL: func.func @fold_yield_direct_zero() -> tensor<32xf64> {
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : tensor<32xf64>
+// CHECK: %[[VAL_1:.*]] = bufferization.alloc_tensor() copy(%[[VAL_0]]) {bufferization.escape = [false], memory_space = 0 : ui64} : tensor<32xf64>
+// CHECK: return %[[VAL_1]] : tensor<32xf64>
+// CHECK: }
+func.func @fold_yield_direct_zero() -> tensor<32xf64> {
+ %cst = arith.constant 0.000000e+00 : f64
+ %0 = linalg.init_tensor [32] : tensor<32xf64>
+ %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ outs(%0 : tensor<32xf64>) {
+ ^bb0(%x: f64):
+ linalg.yield %cst : f64
+ } -> tensor<32xf64>
+ return %1 : tensor<32xf64>
+}
+
// CHECK-LABEL: func.func @sampled_dd_unfused(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x8xf64>,