[Mlir-commits] [mlir] 6f28fd0 - [mlir][vector] Fold 1-element reduction into extract or arith ops

Lei Zhang llvmlistbot at llvm.org
Fri Apr 22 11:25:57 PDT 2022


Author: Lei Zhang
Date: 2022-04-22T14:24:46-04:00
New Revision: 6f28fd0bf7f8a568775ef256d94b92122aa524f8

URL: https://github.com/llvm/llvm-project/commit/6f28fd0bf7f8a568775ef256d94b92122aa524f8
DIFF: https://github.com/llvm/llvm-project/commit/6f28fd0bf7f8a568775ef256d94b92122aa524f8.diff

LOG: [mlir][vector] Fold 1-element reduction into extract or arith ops

If the vector has only a single element, we can just extract that
element to compute the final result, combining it with the accumulator
when one is present.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D124129

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index bc54e948800f8..1d3ac261fdfe5 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -311,6 +311,7 @@ def Vector_ReductionOp :
   // TODO: Migrate to assemblyFormat once `AllTypesMatch` supports optional
   // operands.
   let hasCustomAssemblyFormat = 1;
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index fbf6675671152..3d24b7c655cd3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -488,6 +488,56 @@ Optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
+namespace {
+/// Folds a `vector.reduction` over a single-element vector into a
+/// `vector.extract` of that element. If the op carries an accumulator, the
+/// extracted element is combined with it via `arith.addf` (for `add`) or
+/// `arith.mulf` (for `mul`). Any other accumulator combination (non-float
+/// result type or another combining kind) is left untouched.
+struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReductionOp reductionOp,
+                                PatternRewriter &rewriter) const override {
+    if (reductionOp.getVectorType().getDimSize(0) != 1)
+      return rewriter.notifyMatchFailure(reductionOp,
+                                         "expected a 1-element vector");
+
+    // Reject unsupported accumulator combinations before creating any IR:
+    // a pattern that returns failure must leave the IR unmodified, and an
+    // assert-only check would silently drop the accumulator in release
+    // builds.
+    Value acc = reductionOp.getAcc();
+    CombiningKind kind = reductionOp.getKind();
+    if (acc && (!reductionOp.getType().isa<FloatType>() ||
+                (kind != CombiningKind::ADD && kind != CombiningKind::MUL)))
+      return rewriter.notifyMatchFailure(
+          reductionOp, "unsupported accumulator type or combining kind");
+
+    Location loc = reductionOp.getLoc();
+    Value result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
+                                              reductionOp.getVector(),
+                                              rewriter.getI64ArrayAttr(0));
+
+    if (acc) {
+      // Only float ADD/MUL can reach here (checked above).
+      if (kind == CombiningKind::ADD)
+        result = rewriter.create<arith::AddFOp>(loc, result, acc);
+      else
+        result = rewriter.create<arith::MulFOp>(loc, result, acc);
+    }
+
+    rewriter.replaceOp(reductionOp, result);
+    return success();
+  }
+};
+} // namespace
+
+/// Registers the canonicalization patterns for `vector.reduction`.
+void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<ElideSingleElementReduction>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ContractionOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 824c455aec716..451d1446bb1c1 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1561,3 +1561,47 @@ func.func @extract_element_splat_fold(%a : i32) -> i32 {
   %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
   return %1 : i32
 }
+
+// -----
+
+// CHECK-LABEL: func @reduce_one_element_vector_extract
+//  CHECK-SAME: (%[[V:.+]]: vector<1xf32>)
+//       CHECK:   %[[S:.+]] = vector.extract %[[V]][0] : vector<1xf32>
+//       CHECK:   return %[[S]] : f32
+func.func @reduce_one_element_vector_extract(%a : vector<1xf32>) -> f32 {
+  %s = vector.reduction <add>, %a : vector<1xf32> into f32
+  return %s : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduce_one_element_vector_addf
+//  CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
+//       CHECK:   %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32>
+//       CHECK:   %[[S:.+]] = arith.addf %[[A]], %[[B]] : f32
+//       CHECK:   return %[[S]] : f32
+func.func @reduce_one_element_vector_addf(%a : vector<1xf32>, %b: f32) -> f32 {
+  %s = vector.reduction <add>, %a, %b : vector<1xf32> into f32
+  return %s : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduce_one_element_vector_mulf
+//  CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
+//       CHECK:   %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32>
+//       CHECK:   %[[S:.+]] = arith.mulf %[[A]], %[[B]] : f32
+//       CHECK:   return %[[S]] : f32
+func.func @reduce_one_element_vector_mulf(%a : vector<1xf32>, %b: f32) -> f32 {
+  %s = vector.reduction <mul>, %a, %b : vector<1xf32> into f32
+  return %s : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @dont_reduce_one_element_vector
+//       CHECK: vector.reduction
+func.func @dont_reduce_one_element_vector(%a : vector<4xf32>) -> f32 {
+  %s = vector.reduction <add>, %a : vector<4xf32> into f32
+  return %s : f32
+}


        


More information about the Mlir-commits mailing list