[Mlir-commits] [mlir] b6204b9 - [mlir][Vector] Introduce UnrollVectorOptions to control vector unrolling.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Oct 23 13:52:46 PDT 2020
Author: MaheshRavishankar
Date: 2020-10-23T13:52:26-07:00
New Revision: b6204b995eaa2ec771f947a2109bd2ef338e688c
URL: https://github.com/llvm/llvm-project/commit/b6204b995eaa2ec771f947a2109bd2ef338e688c
DIFF: https://github.com/llvm/llvm-project/commit/b6204b995eaa2ec771f947a2109bd2ef338e688c.diff
LOG: [mlir][Vector] Introduce UnrollVectorOptions to control vector unrolling.
The current pattern for vector unrolling takes the native shape to
unroll to at pattern instantiation time, but the native shape might
differ based on the types of the operands. Introduce a
UnrollVectorOptions struct which allows for using a function that will
return the native shape based on the operation. Move the other unrolling
options, such as `filterConstraint`, into this struct.
Differential Revision: https://reviews.llvm.org/D89744
Added:
mlir/test/Dialect/Vector/vector-unroll-options.mlir
Modified:
mlir/include/mlir/Dialect/Vector/VectorTransforms.h
mlir/test/lib/Transforms/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index 157084a2bff1..a1cf90cb10d5 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -85,21 +85,51 @@ SmallVector<Value, 1> unrollSingleResultVectorOp(OpBuilder &builder,
LogicalResult unrollTransferWriteOp(OpBuilder &builder, Operation *op,
ArrayRef<int64_t> targetShape);
+/// Options that control the vector unrolling.
+struct UnrollVectorOptions {
+ using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
+ /// Callback function that indicates whether vector unrolling should be
+ /// attempted on the operation.
+ FilterConstraintFnType filterConstraint = nullptr;
+ UnrollVectorOptions &setFilterContraint(FilterConstraintFnType constraint) {
+ filterConstraint = constraint;
+ return *this;
+ }
+
+ using NativeShapeFnType =
+ std::function<Optional<SmallVector<int64_t, 4>>(Operation *op)>;
+ /// Function that returns the shape of the vector to unroll to for a given
+ /// operation. The unrolling is aborted if the function returns `llvm::None`.
+ NativeShapeFnType nativeShape = nullptr;
+ UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) {
+ nativeShape = fn;
+ return *this;
+ }
+
+ /// Set the native shape to use for unrolling.
+ UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
+ SmallVector<int64_t, 4> tsShape(shape.begin(), shape.end());
+ nativeShape = [=](Operation *) -> Optional<SmallVector<int64_t, 4>> {
+ return tsShape;
+ };
+ return *this;
+ }
+};
/// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape`
/// declaratively.
template <typename OpTy>
struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
using FilterConstraintType = std::function<LogicalResult(OpTy op)>;
- UnrollVectorPattern(
- ArrayRef<int64_t> targetShape, MLIRContext *context,
- FilterConstraintType constraint = [](OpTy op) { return success(); })
- : OpRewritePattern<OpTy>(context),
- targetShape(targetShape.begin(), targetShape.end()),
- filter(constraint) {}
+ UnrollVectorPattern(MLIRContext *context, UnrollVectorOptions options)
+ : OpRewritePattern<OpTy>(context), options(options) {}
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- if (failed(filter(op)))
+ if (options.filterConstraint && failed(options.filterConstraint(op)))
return failure();
+ if (!options.nativeShape) {
+ return op.emitError("vector unrolling expects the native shape or native"
+ "shape call back function to be set");
+ }
auto unrollableVectorOp =
dyn_cast<VectorUnrollOpInterface>(op.getOperation());
if (!unrollableVectorOp)
@@ -107,19 +137,22 @@ struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
if (!maybeUnrollShape)
return failure();
- auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, targetShape);
+ Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
+ if (!targetShape)
+ return op.emitError("failed to get target shape for vector unroll");
+ auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
if (!maybeShapeRatio ||
llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
return failure();
if (std::is_same<OpTy, TransferWriteOp>::value) {
- if (failed(unrollTransferWriteOp(rewriter, op, targetShape)))
+ if (failed(unrollTransferWriteOp(rewriter, op, *targetShape)))
return failure();
rewriter.eraseOp(op);
return success();
}
if (op.getOperation()->getNumResults() != 1)
return failure();
- auto resultVector = unrollSingleResultVectorOp(rewriter, op, targetShape);
+ auto resultVector = unrollSingleResultVectorOp(rewriter, op, *targetShape);
if (resultVector.size() != 1)
return failure();
rewriter.replaceOp(op, resultVector.front());
@@ -127,8 +160,7 @@ struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
}
private:
- SmallVector<int64_t, 4> targetShape;
- FilterConstraintType filter;
+ UnrollVectorOptions options;
};
/// Split a vector.transfer operation into an unmasked fastpath and a slowpath.
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
new file mode 100644
index 000000000000..705d4ab65739
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt %s -test-vector-unrolling-patterns=unroll-based-on-type | FileCheck %s
+
+func @vector_contract_f32(%lhs : vector<8x8xf32>, %rhs : vector<8x8xf32>,
+ %init : vector<8x8xf32>) -> vector<8x8xf32> {
+ %0 = vector.contract
+ {indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (j, k)>,
+ affine_map<(i, j, k) -> (i, j)>],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ %lhs, %rhs, %init : vector<8x8xf32>, vector<8x8xf32> into vector<8x8xf32>
+ return %0 : vector<8x8xf32>
+}
+// CHECK-LABEL: func @vector_contract_f32
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: return
+
+func @vector_contract_f16(%lhs : vector<8x8xf16>, %rhs : vector<8x8xf16>,
+ %init : vector<8x8xf16>) -> vector<8x8xf16> {
+ %0 = vector.contract
+ {indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (j, k)>,
+ affine_map<(i, j, k) -> (i, j)>],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ %lhs, %rhs, %init : vector<8x8xf16>, vector<8x8xf16> into vector<8x8xf16>
+ return %0 : vector<8x8xf16>
+}
+// CHECK-LABEL: func @vector_contract_f16
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+// CHECK: return
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 52d0f7b2bb5e..5369ab51ddb0 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@@ -26,9 +27,10 @@ struct TestVectorToVectorConversion
void runOnFunction() override {
OwningRewritePatternList patterns;
auto *ctx = &getContext();
- patterns.insert<UnrollVectorPattern<AddFOp>>(ArrayRef<int64_t>{2, 2}, ctx);
+ patterns.insert<UnrollVectorPattern<AddFOp>>(
+ ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
- ArrayRef<int64_t>{2, 2, 2}, ctx);
+ ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2}));
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
populateVectorToVectorTransformationPatterns(patterns, ctx);
applyPatternsAndFoldGreedily(getFunction(), patterns);
@@ -113,16 +115,44 @@ struct TestVectorContractionConversion
struct TestVectorUnrollingPatterns
: public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
+ TestVectorUnrollingPatterns() = default;
+ TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
void runOnFunction() override {
MLIRContext *ctx = &getContext();
OwningRewritePatternList patterns;
- patterns.insert<UnrollVectorPattern<AddFOp>>(ArrayRef<int64_t>{2, 2}, ctx);
- patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
- ArrayRef<int64_t>{2, 2, 2}, ctx);
+ patterns.insert<UnrollVectorPattern<AddFOp>>(
+ ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
+
+ if (unrollBasedOnType) {
+ UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
+ [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
+ vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
+ SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
+ if (auto floatType = contractOp.getLhsType()
+ .getElementType()
+ .dyn_cast<FloatType>()) {
+ if (floatType.getWidth() == 16) {
+ nativeShape[2] = 4;
+ }
+ }
+ return nativeShape;
+ };
+ patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
+ ctx, UnrollVectorOptions().setNativeShapeFn(nativeShapeFn));
+ } else {
+ patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
+ ctx,
+ UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2}));
+ }
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
populateVectorToVectorTransformationPatterns(patterns, ctx);
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
+
+ Option<bool> unrollBasedOnType{
+ *this, "unroll-based-on-type",
+ llvm::cl::desc("Set the unroll factor based on type of the operation"),
+ llvm::cl::init(false)};
};
struct TestVectorDistributePatterns
@@ -165,9 +195,9 @@ struct TestVectorTransferUnrollingPatterns
MLIRContext *ctx = &getContext();
OwningRewritePatternList patterns;
patterns.insert<UnrollVectorPattern<vector::TransferReadOp>>(
- ArrayRef<int64_t>{2, 2}, ctx);
+ ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
patterns.insert<UnrollVectorPattern<vector::TransferWriteOp>>(
- ArrayRef<int64_t>{2, 2}, ctx);
+ ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
populateVectorToVectorTransformationPatterns(patterns, ctx);
applyPatternsAndFoldGreedily(getFunction(), patterns);
More information about the Mlir-commits
mailing list