[Mlir-commits] [mlir] ab45a43 - [mlir][Vector] Support 0-D vectors in FMAOp
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Aug 24 08:50:07 PDT 2022
Author: Michal Terepeta
Date: 2022-08-24T08:49:58-07:00
New Revision: ab45a4329bbf6f80a14d16dcdeef79beff6653ec
URL: https://github.com/llvm/llvm-project/commit/ab45a4329bbf6f80a14d16dcdeef79beff6653ec
DIFF: https://github.com/llvm/llvm-project/commit/ab45a4329bbf6f80a14d16dcdeef79beff6653ec.diff
LOG: [mlir][Vector] Support 0-D vectors in FMAOp
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D115742
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index aeffe601324e0..033c29bca60f6 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -650,8 +650,10 @@ def Vector_FMAOp :
NoSideEffect, AllTypesMatch<["lhs", "rhs", "acc", "result"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
] # ElementwiseMappable.traits>,
- Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc)>,
- Results<(outs AnyVector:$result)> {
+ Arguments<(ins AnyVectorOfAnyRank:$lhs,
+ AnyVectorOfAnyRank:$rhs,
+ AnyVectorOfAnyRank:$acc)>,
+ Results<(outs AnyVectorOfAnyRank:$result)> {
let summary = "vector fused multiply-add";
let description = [{
Multiply-add expressions operate on n-D vectors and compute a fused
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index c3a18ae29ab49..8804f971ad5d2 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -746,7 +746,7 @@ class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
VectorType vType = fmaOp.getVectorType();
- if (vType.getRank() != 1)
+ if (vType.getRank() > 1)
return failure();
rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 37827bd1c226c..40e4022461d04 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1084,7 +1084,7 @@ func.func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf3
// -----
-func.func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf32>) -> (vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>) {
+func.func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf32>, %d: vector<f32>) -> (vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>, vector<f32>) {
// CHECK-LABEL: @vector_fma
// CHECK-SAME: %[[A:.*]]: vector<8xf32>
// CHECK-SAME: %[[B:.*]]: vector<2x4xf32>
@@ -1112,7 +1112,11 @@ func.func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf
// CHECK-SAME: (vector<1xf32>, vector<1xf32>, vector<1xf32>) -> vector<1xf32>
%2 = vector.fma %c, %c, %c : vector<1x1x1xf32>
- return %0, %1, %2: vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>
+ // CHECK: %[[D0:.*]] = "llvm.intr.fmuladd"
+ // CHECK-SAME: (vector<1xf32>, vector<1xf32>, vector<1xf32>) -> vector<1xf32>
+ %3 = vector.fma %d, %d, %d : vector<f32>
+
+ return %0, %1, %2, %3: vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>, vector<f32>
}
// -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index ab8d1e07d99c7..4c3e322131715 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -478,11 +478,13 @@ func.func @bitcast(%arg0 : vector<5x1x3x2xf32>,
}
// CHECK-LABEL: @vector_fma
-func.func @vector_fma(%a: vector<8xf32>, %b: vector<8x4xf32>) {
+func.func @vector_fma(%a: vector<8xf32>, %b: vector<8x4xf32>, %c: vector<f32>) {
// CHECK: vector.fma %{{.*}} : vector<8xf32>
vector.fma %a, %a, %a : vector<8xf32>
// CHECK: vector.fma %{{.*}} : vector<8x4xf32>
vector.fma %b, %b, %b : vector<8x4xf32>
+ // CHECK: vector.fma %{{.*}} : vector<f32>
+ vector.fma %c, %c, %c : vector<f32>
return
}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
index 95e52192fa8c9..a29ab10ddbd2c 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -105,9 +105,19 @@ func.func @create_mask_0d(%zero : index, %one : index) {
return
}
-func.func @reduce_add(%arg0: vector<f32>) -> f32 {
+func.func @reduce_add(%arg0: vector<f32>) {
%0 = vector.reduction <add>, %arg0 : vector<f32> into f32
- return %0 : f32
+ vector.print %0 : f32
+ // CHECK: 5
+ return
+}
+
+func.func @fma_0d(%four: vector<f32>) {
+ %0 = vector.fma %four, %four, %four : vector<f32>
+ // 4 * 4 + 4 = 20
+ // CHECK: ( 20 )
+ vector.print %0: vector<f32>
+ return
}
func.func @entry() {
@@ -137,9 +147,10 @@ func.func @entry() {
call @create_mask_0d(%zero_idx, %one_idx) : (index, index) -> ()
%red_array = arith.constant dense<5.0> : vector<f32>
- %red_res = call @reduce_add(%red_array) : (vector<f32>) -> (f32)
- vector.print %red_res : f32
- // CHECK: 5
+ call @reduce_add(%red_array) : (vector<f32>) -> ()
+
+ %5 = arith.constant dense<4.0> : vector<f32>
+ call @fma_0d(%5) : (vector<f32>) -> ()
return
}
More information about the Mlir-commits mailing list