[Mlir-commits] [mlir] c7bb69b - [mlir][sparse] replace zero yield generic op with copy in allocation

Aart Bik llvmlistbot at llvm.org
Thu Aug 4 09:34:11 PDT 2022


Author: Aart Bik
Date: 2022-08-04T09:33:57-07:00
New Revision: c7bb69bc7546887b98c5cc3e5c50318f85b56eaf

URL: https://github.com/llvm/llvm-project/commit/c7bb69bc7546887b98c5cc3e5c50318f85b56eaf
DIFF: https://github.com/llvm/llvm-project/commit/c7bb69bc7546887b98c5cc3e5c50318f85b56eaf.diff

LOG: [mlir][sparse] replace zero yield generic op with copy in allocation

This prepares patterns that are sometimes generated by the front-end and that
would otherwise prohibit fusion of SDDMM-flavored kernels.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D131126

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index f0aef0f0f6386..b4cd6c36ced7a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -10,6 +10,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "CodegenUtils.h"
+
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -94,12 +96,50 @@ static bool isSumOfMul(GenericOp op) {
   return false;
 }
 
+// Helper to detect direct yield of a zero value.
+static bool isZeroYield(GenericOp op) {
+  auto yieldOp = cast<linalg::YieldOp>(op.region().front().getTerminator());
+  if (auto arg = yieldOp.getOperand(0).dyn_cast<BlockArgument>()) {
+    if (arg.getOwner()->getParentOp() == op) {
+      OpOperand *t = op.getInputAndOutputOperands()[arg.getArgNumber()];
+      return matchPattern(t->get(), m_Zero()) ||
+             matchPattern(t->get(), m_AnyZeroFloat());
+    }
+  } else if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
+    return matchPattern(def, m_Zero()) || matchPattern(def, m_AnyZeroFloat());
+  }
+  return false;
+}
+
 //===---------------------------------------------------------------------===//
 // The actual sparse tensor rewriting rules.
 //===---------------------------------------------------------------------===//
 
 namespace {
 
+/// Rewriting rule that folds a direct zero yield into the initial allocation.
+struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
+public:
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!op.hasTensorSemantics() || op.getNumResults() != 1 ||
+        !isAlloc(op.getOutputOperand(0), /*isZero=*/false) || !isZeroYield(op))
+      return failure();
+    auto outputType = op.getResult(0).getType().cast<RankedTensorType>();
+    if (!outputType.hasStaticShape() || getSparseTensorEncoding(outputType))
+      return failure();
+    // Incorporate zero value into allocation copy.
+    Value zero = constantZero(rewriter, op.getLoc(), op.getResult(0).getType());
+    AllocTensorOp a =
+        op.getOutputOperand(0)->get().getDefiningOp<AllocTensorOp>();
+    rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(zero); });
+    rewriter.replaceOp(op, op.getOutputOperand(0)->get());
+    return success();
+  }
+};
+
 /// Rewriting rule that converts two kernels:
 ///
 ///      T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
@@ -187,11 +227,13 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
     rewriter.create<linalg::YieldOp>(loc, last);
     // Force initial value on merged allocation for dense outputs.
     if (!getSparseTensorEncoding(op.getResult(0).getType())) {
-      AllocTensorOp a1 =
-          prod.getOutputOperand(0)->get().getDefiningOp<AllocTensorOp>();
-      AllocTensorOp a2 =
+      Value init = prod.getOutputOperand(0)
+                       ->get()
+                       .getDefiningOp<AllocTensorOp>()
+                       .getCopy();
+      AllocTensorOp a =
           op.getOutputOperand(0)->get().getDefiningOp<AllocTensorOp>();
-      a2.getCopyMutable().assign(a1.getCopy());
+      rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(init); });
     }
     // Replace consumer with fused operation. Old producer
     // and consumer ops will be removed by DCE.
@@ -253,7 +295,7 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
 //===---------------------------------------------------------------------===//
 
 void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns) {
-  patterns
-      .add<FuseSparseMultiplyOverAdd, ReshapeRewriter<tensor::ExpandShapeOp>,
-           ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
+  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd,
+               ReshapeRewriter<tensor::ExpandShapeOp>,
+               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
 }

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
old mode 100644
new mode 100755
index 4dd4fd328ce78..9ce52a0a3520d
--- a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
@@ -20,6 +20,42 @@
   iterator_types = ["parallel", "parallel"]
 }
 
+// CHECK-LABEL: func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
+// CHECK:         %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf64>
+// CHECK:         %[[VAL_1:.*]] = bufferization.alloc_tensor() copy(%[[VAL_0]]) {bufferization.escape = [false], memory_space = 0 : ui64} : tensor<1024x1024xf64>
+// CHECK:         return %[[VAL_1]] : tensor<1024x1024xf64>
+// CHECK:       }
+func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
+  %cst = arith.constant 0.000000e+00 : f64
+  %0 = linalg.init_tensor [1024, 1024] : tensor<1024x1024xf64>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>,
+                                        affine_map<(d0, d1) -> (d0, d1)>],
+                                        iterator_types = ["parallel", "parallel"]}
+                                        ins(%cst : f64)
+                                        outs(%0 : tensor<1024x1024xf64>) {
+    ^bb0(%a: f64, %x: f64):
+      linalg.yield %a : f64
+    } -> tensor<1024x1024xf64>
+  return %1 : tensor<1024x1024xf64>
+}
+
+// CHECK-LABEL: func.func @fold_yield_direct_zero() -> tensor<32xf64> {
+// CHECK:         %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : tensor<32xf64>
+// CHECK:         %[[VAL_1:.*]] = bufferization.alloc_tensor() copy(%[[VAL_0]]) {bufferization.escape = [false], memory_space = 0 : ui64} : tensor<32xf64>
+// CHECK:         return %[[VAL_1]] : tensor<32xf64>
+// CHECK:       }
+func.func @fold_yield_direct_zero() -> tensor<32xf64> {
+  %cst = arith.constant 0.000000e+00 : f64
+  %0 = linalg.init_tensor [32] : tensor<32xf64>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>],
+                                        iterator_types = ["parallel"]}
+                                        outs(%0 : tensor<32xf64>) {
+    ^bb0(%x: f64):
+      linalg.yield %cst : f64
+    } -> tensor<32xf64>
+  return %1 : tensor<32xf64>
+}
+
 // CHECK-LABEL: func.func @sampled_dd_unfused(
 // CHECK-SAME:    %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>,
 // CHECK-SAME:    %[[VAL_1:.*]]: tensor<8x8xf64>,


        


More information about the Mlir-commits mailing list