[Mlir-commits] [mlir] 6f28fd0 - [mlir][vector] Fold 1-element reduction into extract or arith ops
Lei Zhang
llvmlistbot at llvm.org
Fri Apr 22 11:25:57 PDT 2022
Author: Lei Zhang
Date: 2022-04-22T14:24:46-04:00
New Revision: 6f28fd0bf7f8a568775ef256d94b92122aa524f8
URL: https://github.com/llvm/llvm-project/commit/6f28fd0bf7f8a568775ef256d94b92122aa524f8
DIFF: https://github.com/llvm/llvm-project/commit/6f28fd0bf7f8a568775ef256d94b92122aa524f8.diff
LOG: [mlir][vector] Fold 1-element reduction into extract or arith ops
If the vector has only a single element, we can simply extract that
element to compute the final result, combining it with the accumulator
(if one is present) via the corresponding arith op.
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D124129
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index bc54e948800f8..1d3ac261fdfe5 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -311,6 +311,7 @@ def Vector_ReductionOp :
// TODO: Migrate to assemblyFormat once `AllTypesMatch` supports optional
// operands.
let hasCustomAssemblyFormat = 1;
+ let hasCanonicalizer = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index fbf6675671152..3d24b7c655cd3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -488,6 +488,45 @@ Optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}
+namespace {
+/// Folds a `vector.reduction` over a single-element 1-D vector into a
+/// `vector.extract` of that element. If the reduction carries an accumulator,
+/// the extracted element is combined with it through the matching scalar
+/// arith op, e.g.
+///
+///   %r = vector.reduction <add>, %v, %acc : vector<1xf32> into f32
+///
+/// becomes
+///
+///   %e = vector.extract %v[0] : vector<1xf32>
+///   %r = arith.addf %e, %acc : f32
+///
+/// Every match condition is checked before any IR is created, so the pattern
+/// never mutates IR and then reports failure.
+struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReductionOp reductionOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType vectorType = reductionOp.getVectorType();
+    // Only 1-D vectors holding exactly one element are handled here.
+    if (vectorType.getRank() != 1 || vectorType.getDimSize(0) != 1)
+      return failure();
+
+    Value acc = reductionOp.getAcc();
+    CombiningKind kind = reductionOp.getKind();
+    // Only float add/mul accumulators are folded. Bail out on anything else
+    // instead of relying on asserts: in release builds an assert would vanish
+    // and the accumulator would be silently dropped from the result.
+    if (acc) {
+      if (!reductionOp.getType().isa<FloatType>())
+        return failure();
+      if (kind != CombiningKind::ADD && kind != CombiningKind::MUL)
+        return failure();
+    }
+
+    Location loc = reductionOp.getLoc();
+    Value result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
+                                              reductionOp.getVector(),
+                                              rewriter.getI64ArrayAttr(0));
+
+    // Combine the lone element with the accumulator, if any.
+    if (acc) {
+      if (kind == CombiningKind::ADD)
+        result = rewriter.create<arith::AddFOp>(loc, result, acc);
+      else
+        result = rewriter.create<arith::MulFOp>(loc, result, acc);
+    }
+
+    rewriter.replaceOp(reductionOp, result);
+    return success();
+  }
+};
+} // namespace
+
+/// Registers the canonicalization patterns for `vector.reduction`; at present
+/// this is only the single-element elision pattern.
+void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                              MLIRContext *ctx) {
+  patterns.add<ElideSingleElementReduction>(ctx);
+}
+
//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 824c455aec716..451d1446bb1c1 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1561,3 +1561,47 @@ func.func @extract_element_splat_fold(%a : i32) -> i32 {
%1 = vector.extractelement %v[%i : i32] : vector<4xi32>
return %1 : i32
}
+
+// -----
+
+// CHECK-LABEL: func @reduce_one_element_vector_extract
+// CHECK-SAME: (%[[V:.+]]: vector<1xf32>)
+// CHECK: %[[S:.+]] = vector.extract %[[V]][0] : vector<1xf32>
+// CHECK: return %[[S]] : f32
+func.func @reduce_one_element_vector_extract(%a : vector<1xf32>) -> f32 {
+  %s = vector.reduction <add>, %a : vector<1xf32> into f32
+  return %s : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduce_one_element_vector_addf
+// CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
+// CHECK: %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32>
+// CHECK: %[[S:.+]] = arith.addf %[[A]], %[[B]] : f32
+// CHECK: return %[[S]]
+func.func @reduce_one_element_vector_addf(%a : vector<1xf32>, %b: f32) -> f32 {
+  %s = vector.reduction <add>, %a, %b : vector<1xf32> into f32
+  return %s : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduce_one_element_vector_mulf
+// CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
+// CHECK: %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32>
+// CHECK: %[[S:.+]] = arith.mulf %[[A]], %[[B]] : f32
+// CHECK: return %[[S]]
+func.func @reduce_one_element_vector_mulf(%a : vector<1xf32>, %b: f32) -> f32 {
+  %s = vector.reduction <mul>, %a, %b : vector<1xf32> into f32
+  return %s : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @dont_reduce_one_element_vector
+// CHECK: vector.reduction
+func.func @dont_reduce_one_element_vector(%a : vector<4xf32>) -> f32 {
+  %s = vector.reduction <add>, %a : vector<4xf32> into f32
+  return %s : f32
+}
More information about the Mlir-commits
mailing list