[Mlir-commits] [mlir] de5022c - [mlir][vector] Implement unrolling of ReductionOp
Matthias Springer
llvmlistbot at llvm.org
Mon Mar 14 09:21:48 PDT 2022
Author: Matthias Springer
Date: 2022-03-15T01:21:24+09:00
New Revision: de5022c7d7abdfb7720e63ac88dfc35b51eb60ed
URL: https://github.com/llvm/llvm-project/commit/de5022c7d7abdfb7720e63ac88dfc35b51eb60ed
DIFF: https://github.com/llvm/llvm-project/commit/de5022c7d7abdfb7720e63ac88dfc35b51eb60ed.diff
LOG: [mlir][vector] Implement unrolling of ReductionOp
Differential Revision: https://reviews.llvm.org/D121597
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
mlir/test/Dialect/Vector/vector-unroll-options.mlir
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index b0012924e5bae..69c2c929e42ff 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -266,7 +266,9 @@ def Vector_ContractionOp :
def Vector_ReductionOp :
Vector_Op<"reduction", [NoSideEffect,
PredOpTrait<"source operand and result have same element type",
- TCresVTEtIsSameAsOpBase<0, 0>>]>,
+ TCresVTEtIsSameAsOpBase<0, 0>>,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface,
+ ["getShapeForUnroll"]>]>,
Arguments<(ins Vector_CombiningKindAttr:$kind, AnyVector:$vector,
Optional<AnyType>:$acc)>,
Results<(outs AnyType:$dest)> {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 61b5c7aac9f6f..fd6fea25b4e0f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -484,6 +484,10 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
return nullptr;
}
+Optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
+ return llvm::to_vector<4>(getVectorType().getShape());
+}
+
//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
index 7ec69f006dba5..2b2042e1f36ce 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
@@ -631,13 +631,60 @@ struct TransferWriteInsertPattern
}
};
+struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
+ UnrollReductionPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options)
+ : OpRewritePattern<vector::ReductionOp>(context, /*benefit=*/1),
+ options(options) {}
+
+ LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
+ PatternRewriter &rewriter) const override {
+ Optional<SmallVector<int64_t, 4>> targetShape =
+ getTargetShape(options, reductionOp);
+ if (!targetShape)
+ return failure();
+ SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
+ int64_t ratio = (*shapeRatio(originalSize, *targetShape))[0];
+
+ // Create unrolled vector reduction.
+ Location loc = reductionOp.getLoc();
+ Value accumulator = nullptr;
+ for (int64_t i = 0; i < ratio; ++i) {
+ SmallVector<int64_t> offsets =
+ getVectorOffset(originalSize, *targetShape, i);
+ SmallVector<int64_t> strides(offsets.size(), 1);
+ Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, reductionOp.vector(), offsets, *targetShape, strides);
+ Operation *newOp = cloneOpWithOperandsAndTypes(
+ rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
+ Value result = newOp->getResult(0);
+
+ if (!accumulator) {
+ // This is the first reduction.
+ accumulator = result;
+ } else {
+ // On subsequent reduction, combine with the accumulator.
+ accumulator = makeArithReduction(rewriter, loc, reductionOp.kind(),
+ accumulator, result);
+ }
+ }
+
+ rewriter.replaceOp(reductionOp, accumulator);
+ return success();
+ }
+
+private:
+ const vector::UnrollVectorOptions options;
+};
+
} // namespace
void mlir::vector::populateVectorUnrollPatterns(
RewritePatternSet &patterns, const UnrollVectorOptions &options) {
patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
UnrollContractionPattern, UnrollElementwisePattern,
- UnrollMultiReductionPattern>(patterns.getContext(), options);
+ UnrollReductionPattern, UnrollMultiReductionPattern>(
+ patterns.getContext(), options);
}
void mlir::vector::populatePropagateVectorDistributionPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index dd1a6fd781e47..5a0014451b2c4 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -106,3 +106,23 @@ func @vector_multi_reduction(%v : vector<4x6xf32>) -> vector<4xf32> {
// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[A1]], %[[V0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[A3]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
// CHECK: return %[[V2]] : vector<4xf32>
+
+// CHECK-LABEL: func @vector_reduction(
+// CHECK-SAME: %[[v:.*]]: vector<8xf32>
+// CHECK: %[[s0:.*]] = vector.extract_strided_slice %[[v]] {offsets = [0], sizes = [2]
+// CHECK: %[[r0:.*]] = vector.reduction <add>, %[[s0]]
+// CHECK: %[[s1:.*]] = vector.extract_strided_slice %[[v]] {offsets = [2], sizes = [2]
+// CHECK: %[[r1:.*]] = vector.reduction <add>, %[[s1]]
+// CHECK: %[[add1:.*]] = arith.addf %[[r0]], %[[r1]]
+// CHECK: %[[s2:.*]] = vector.extract_strided_slice %[[v]] {offsets = [4], sizes = [2]
+// CHECK: %[[r2:.*]] = vector.reduction <add>, %[[s2]]
+// CHECK: %[[add2:.*]] = arith.addf %[[add1]], %[[r2]]
+// CHECK: %[[s3:.*]] = vector.extract_strided_slice %[[v]] {offsets = [6], sizes = [2]
+// CHECK: %[[r3:.*]] = vector.reduction <add>, %[[s3]]
+// CHECK: %[[add3:.*]] = arith.addf %[[add2]], %[[r3]]
+// CHECK: return %[[add3]]
+func @vector_reduction(%v : vector<8xf32>) -> f32 {
+ %0 = vector.reduction <add>, %v : vector<8xf32> into f32
+ return %0 : f32
+}
+
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 2bf5e3f1a8e7d..f139e3cdcd68e 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -268,6 +268,12 @@ struct TestVectorUnrollingPatterns
return success(isa<arith::AddFOp, vector::FMAOp,
vector::MultiDimReductionOp>(op));
}));
+ populateVectorUnrollPatterns(
+ patterns, UnrollVectorOptions()
+ .setNativeShape(ArrayRef<int64_t>{2})
+ .setFilterConstraint([](Operation *op) {
+ return success(isa<vector::ReductionOp>(op));
+ }));
if (unrollBasedOnType) {
UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
More information about the Mlir-commits mailing list