[Mlir-commits] [mlir] de5022c - [mlir][vector] Implement unrolling of ReductionOp

Matthias Springer llvmlistbot at llvm.org
Mon Mar 14 09:21:48 PDT 2022


Author: Matthias Springer
Date: 2022-03-15T01:21:24+09:00
New Revision: de5022c7d7abdfb7720e63ac88dfc35b51eb60ed

URL: https://github.com/llvm/llvm-project/commit/de5022c7d7abdfb7720e63ac88dfc35b51eb60ed
DIFF: https://github.com/llvm/llvm-project/commit/de5022c7d7abdfb7720e63ac88dfc35b51eb60ed.diff

LOG: [mlir][vector] Implement unrolling of ReductionOp

Differential Revision: https://reviews.llvm.org/D121597

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
    mlir/test/Dialect/Vector/vector-unroll-options.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index b0012924e5bae..69c2c929e42ff 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -266,7 +266,9 @@ def Vector_ContractionOp :
 def Vector_ReductionOp :
   Vector_Op<"reduction", [NoSideEffect,
      PredOpTrait<"source operand and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
+                 TCresVTEtIsSameAsOpBase<0, 0>>,
+     DeclareOpInterfaceMethods<VectorUnrollOpInterface,
+                               ["getShapeForUnroll"]>]>, // Queried by UnrollReductionPattern.
    Arguments<(ins Vector_CombiningKindAttr:$kind, AnyVector:$vector,
                   Optional<AnyType>:$acc)>,
    Results<(outs AnyType:$dest)> {

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 61b5c7aac9f6f..fd6fea25b4e0f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -484,6 +484,10 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
   return nullptr;
 }
 
+Optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
+  return llvm::to_vector<4>(getVectorType().getShape());
+} // Always returns a shape, so callers may dereference the result.
+
 //===----------------------------------------------------------------------===//
 // ContractionOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
index 7ec69f006dba5..2b2042e1f36ce 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
@@ -631,13 +631,74 @@ struct TransferWriteInsertPattern
   }
 };
 
+/// Unrolls a `vector.reduction` into one reduction per `targetShape`-sized
+/// slice of the source vector and combines the partial results with the op's
+/// combining kind. The optional `acc` operand of the original op is folded in
+/// exactly once, after all slice results have been combined.
+struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
+  UnrollReductionPattern(MLIRContext *context,
+                         const vector::UnrollVectorOptions &options)
+      : OpRewritePattern<vector::ReductionOp>(context, /*benefit=*/1),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
+                                PatternRewriter &rewriter) const override {
+    Optional<SmallVector<int64_t, 4>> targetShape =
+        getTargetShape(options, reductionOp);
+    if (!targetShape)
+      return failure();
+    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
+    // Only the leading dimension's ratio is used as the slice count, so this
+    // assumes a 1-D source vector.
+    int64_t ratio = (*shapeRatio(originalSize, *targetShape))[0];
+
+    // Create unrolled vector reduction.
+    Location loc = reductionOp.getLoc();
+    Value accumulator = nullptr;
+    for (int64_t i = 0; i < ratio; ++i) {
+      SmallVector<int64_t> offsets =
+          getVectorOffset(originalSize, *targetShape, i);
+      SmallVector<int64_t> strides(offsets.size(), 1);
+      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, reductionOp.vector(), offsets, *targetShape, strides);
+      // Only the sliced vector is passed as an operand, so the clone carries
+      // no accumulator; the original `acc` is folded in after the loop.
+      Operation *newOp = cloneOpWithOperandsAndTypes(
+          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
+      Value result = newOp->getResult(0);
+
+      if (!accumulator) {
+        // This is the first reduction.
+        accumulator = result;
+      } else {
+        // On subsequent reductions, combine with the accumulator.
+        accumulator = makeArithReduction(rewriter, loc, reductionOp.kind(),
+                                         accumulator, result);
+      }
+    }
+
+    // Fold in the original op's optional `acc` operand; without this the
+    // unrolled result would silently ignore it.
+    if (Value acc = reductionOp.acc())
+      accumulator = makeArithReduction(rewriter, loc, reductionOp.kind(),
+                                       accumulator, acc);
+
+    rewriter.replaceOp(reductionOp, accumulator);
+    return success();
+  }
+
+private:
+  const vector::UnrollVectorOptions options;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
     RewritePatternSet &patterns, const UnrollVectorOptions &options) {
   patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
                UnrollContractionPattern, UnrollElementwisePattern,
-               UnrollMultiReductionPattern>(patterns.getContext(), options);
+               UnrollReductionPattern, UnrollMultiReductionPattern>(
+      patterns.getContext(), options);
 }
 
 void mlir::vector::populatePropagateVectorDistributionPatterns(

diff  --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index dd1a6fd781e47..5a0014451b2c4 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -106,3 +106,23 @@ func @vector_multi_reduction(%v : vector<4x6xf32>) -> vector<4xf32> {
 //       CHECK:   %[[V1:.*]] = vector.insert_strided_slice %[[A1]], %[[V0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
 //       CHECK:   %[[V2:.*]] = vector.insert_strided_slice %[[A3]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
 //       CHECK:   return %[[V2]] : vector<4xf32>
+
+// CHECK-LABEL: func @vector_reduction(
+//  CHECK-SAME:     %[[v:.*]]: vector<8xf32>
+//       CHECK:   %[[s0:.*]] = vector.extract_strided_slice %[[v]] {offsets = [0], sizes = [2]
+//       CHECK:   %[[r0:.*]] = vector.reduction <add>, %[[s0]]
+//       CHECK:   %[[s1:.*]] = vector.extract_strided_slice %[[v]] {offsets = [2], sizes = [2]
+//       CHECK:   %[[r1:.*]] = vector.reduction <add>, %[[s1]]
+//       CHECK:   %[[add1:.*]] = arith.addf %[[r0]], %[[r1]]
+//       CHECK:   %[[s2:.*]] = vector.extract_strided_slice %[[v]] {offsets = [4], sizes = [2]
+//       CHECK:   %[[r2:.*]] = vector.reduction <add>, %[[s2]]
+//       CHECK:   %[[add2:.*]] = arith.addf %[[add1]], %[[r2]]
+//       CHECK:   %[[s3:.*]] = vector.extract_strided_slice %[[v]] {offsets = [6], sizes = [2]
+//       CHECK:   %[[r3:.*]] = vector.reduction <add>, %[[s3]]
+//       CHECK:   %[[add3:.*]] = arith.addf %[[add2]], %[[r3]]
+//       CHECK:   return %[[add3]]
+func @vector_reduction(%v : vector<8xf32>) -> f32 { // 4 slices of vector<2xf32>
+  %0 = vector.reduction <add>, %v : vector<8xf32> into f32
+  return %0 : f32
+}
+

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 2bf5e3f1a8e7d..f139e3cdcd68e 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -268,6 +268,12 @@ struct TestVectorUnrollingPatterns
                         return success(isa<arith::AddFOp, vector::FMAOp,
                                            vector::MultiDimReductionOp>(op));
                       }));
+    populateVectorUnrollPatterns( // vector.reduction only, 2-element slices
+        patterns, UnrollVectorOptions()
+                      .setNativeShape(ArrayRef<int64_t>{2})
+                      .setFilterConstraint([](Operation *op) {
+                        return success(isa<vector::ReductionOp>(op));
+                      }));
 
     if (unrollBasedOnType) {
       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =


        


More information about the Mlir-commits mailing list