[Mlir-commits] [mlir] d8c7f44 - [mlir][vector] Add support for unrolling vector.fma
Lei Zhang
llvmlistbot at llvm.org
Tue Feb 16 07:00:06 PST 2021
Author: Lei Zhang
Date: 2021-02-16T09:56:25-05:00
New Revision: d8c7f442eaf21a5ad42a5ac101f66b69984ef065
URL: https://github.com/llvm/llvm-project/commit/d8c7f442eaf21a5ad42a5ac101f66b69984ef065
DIFF: https://github.com/llvm/llvm-project/commit/d8c7f442eaf21a5ad42a5ac101f66b69984ef065.diff
LOG: [mlir][vector] Add support for unrolling vector.fma
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D96781
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/vector-unroll-options.mlir
mlir/test/lib/Transforms/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index a7c12231a91f..cf18cd89e170 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -583,8 +583,10 @@ def Vector_ExtractMapOp :
}
def Vector_FMAOp :
- Op<Vector_Dialect, "fma", [NoSideEffect,
- AllTypesMatch<["lhs", "rhs", "acc", "result"]>]>,
+ Op<Vector_Dialect, "fma", [
+ NoSideEffect, AllTypesMatch<["lhs", "rhs", "acc", "result"]>,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
+ ]>,
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc)>,
Results<(outs AnyVector:$result)> {
let summary = "vector fused multiply-add";
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 671cd865b1c3..af884f9d6ce6 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1258,6 +1258,14 @@ AffineMap calculateImplicitMap(MapOp op) {
AffineMap ExtractMapOp::map() { return calculateImplicitMap(*this); }
+//===----------------------------------------------------------------------===//
+// FMAOp
+//===----------------------------------------------------------------------===//
+
+/// VectorUnrollOpInterface hook: returns the full shape of this op's vector
+/// type as the shape to unroll over.
+Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
+ return llvm::to_vector<4>(getVectorType().getShape());
+}
+
//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//
@@ -2456,8 +2464,7 @@ OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
}
+/// Unroll hook for TransferReadOp: returns the full shape of this op's vector
+/// type as the shape to unroll over.
Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
- auto s = getVectorType().getShape();
- return SmallVector<int64_t, 4>{s.begin(), s.end()};
+ return llvm::to_vector<4>(getVectorType().getShape());
}
void TransferReadOp::getEffects(
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 705d4ab65739..581039c48cb7 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -73,3 +73,10 @@ func @vector_contract_f16(%lhs : vector<8x8xf16>, %rhs : vector<8x8xf16>,
// CHECK: vector.contract {
// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
// CHECK: return
+
+// The 4x4 vector.fma below is expected to be unrolled into four 2x2
+// vector.fma ops, as verified by the checks that follow the function.
+func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf32>) -> vector<4x4xf32> {
+ %0 = vector.fma %a, %b, %c: vector<4x4xf32>
+ return %0 : vector<4x4xf32>
+}
+// CHECK-LABEL: func @vector_fma
+// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 61b17178ef59..8ec970f68b23 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -151,8 +151,9 @@ struct TestVectorUnrollingPatterns
patterns.insert<UnrollVectorPattern>(
ctx, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{2, 2})
- .setFilterConstraint(
- [](Operation *op) { return success(isa<AddFOp>(op)); }));
+ .setFilterConstraint([](Operation *op) {
+ return success(isa<AddFOp, vector::FMAOp>(op));
+ }));
if (unrollBasedOnType) {
UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
More information about the Mlir-commits mailing list