[Mlir-commits] [mlir] ab45a43 - [mlir][Vector] Support 0-D vectors in FMAOp

Nicolas Vasilache llvmlistbot at llvm.org
Wed Aug 24 08:50:07 PDT 2022


Author: Michal Terepeta
Date: 2022-08-24T08:49:58-07:00
New Revision: ab45a4329bbf6f80a14d16dcdeef79beff6653ec

URL: https://github.com/llvm/llvm-project/commit/ab45a4329bbf6f80a14d16dcdeef79beff6653ec
DIFF: https://github.com/llvm/llvm-project/commit/ab45a4329bbf6f80a14d16dcdeef79beff6653ec.diff

LOG: [mlir][Vector] Support 0-D vectors in FMAOp

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D115742

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/ops.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index aeffe601324e0..033c29bca60f6 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -650,8 +650,10 @@ def Vector_FMAOp :
        NoSideEffect, AllTypesMatch<["lhs", "rhs", "acc", "result"]>,
        DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
      ] # ElementwiseMappable.traits>,
-    Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc)>,
-    Results<(outs AnyVector:$result)> {
+    Arguments<(ins AnyVectorOfAnyRank:$lhs,
+                   AnyVectorOfAnyRank:$rhs,
+                   AnyVectorOfAnyRank:$acc)>,
+    Results<(outs AnyVectorOfAnyRank:$result)> {
   let summary = "vector fused multiply-add";
   let description = [{
     Multiply-add expressions operate on n-D vectors and compute a fused

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index c3a18ae29ab49..8804f971ad5d2 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -746,7 +746,7 @@ class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     VectorType vType = fmaOp.getVectorType();
-    if (vType.getRank() != 1)
+    if (vType.getRank() > 1)
       return failure();
     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
         fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 37827bd1c226c..40e4022461d04 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1084,7 +1084,7 @@ func.func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf3
 
 // -----
 
-func.func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf32>) -> (vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>) {
+func.func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf32>, %d: vector<f32>) -> (vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>, vector<f32>) {
   // CHECK-LABEL: @vector_fma
   //  CHECK-SAME: %[[A:.*]]: vector<8xf32>
   //  CHECK-SAME: %[[B:.*]]: vector<2x4xf32>
@@ -1112,7 +1112,11 @@ func.func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf
   //  CHECK-SAME:   (vector<1xf32>, vector<1xf32>, vector<1xf32>) -> vector<1xf32>
   %2 = vector.fma %c, %c, %c : vector<1x1x1xf32>
 
-  return %0, %1, %2: vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>
+  //       CHECK: %[[D0:.*]] = "llvm.intr.fmuladd"
+  //  CHECK-SAME:   (vector<1xf32>, vector<1xf32>, vector<1xf32>) -> vector<1xf32>
+  %3 = vector.fma %d, %d, %d : vector<f32>
+
+  return %0, %1, %2, %3: vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>, vector<f32>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index ab8d1e07d99c7..4c3e322131715 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -478,11 +478,13 @@ func.func @bitcast(%arg0 : vector<5x1x3x2xf32>,
 }
 
 // CHECK-LABEL: @vector_fma
-func.func @vector_fma(%a: vector<8xf32>, %b: vector<8x4xf32>) {
+func.func @vector_fma(%a: vector<8xf32>, %b: vector<8x4xf32>, %c: vector<f32>) {
   // CHECK: vector.fma %{{.*}} : vector<8xf32>
   vector.fma %a, %a, %a : vector<8xf32>
   // CHECK: vector.fma %{{.*}} : vector<8x4xf32>
   vector.fma %b, %b, %b : vector<8x4xf32>
+  // CHECK: vector.fma %{{.*}} : vector<f32>
+  vector.fma %c, %c, %c : vector<f32>
   return
 }
 

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
index 95e52192fa8c9..a29ab10ddbd2c 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -105,9 +105,19 @@ func.func @create_mask_0d(%zero : index, %one : index) {
   return
 }
 
-func.func @reduce_add(%arg0: vector<f32>) -> f32 {
+func.func @reduce_add(%arg0: vector<f32>) {
   %0 = vector.reduction <add>, %arg0 : vector<f32> into f32    
-  return %0 : f32
+  vector.print %0 : f32
+  // CHECK: 5
+  return
+}
+
+func.func @fma_0d(%four: vector<f32>) {
+  %0 = vector.fma %four, %four, %four : vector<f32>
+  // 4 * 4 + 4 = 20
+  // CHECK: ( 20 )
+  vector.print %0: vector<f32>
+  return
 }
 
 func.func @entry() {
@@ -137,9 +147,10 @@ func.func @entry() {
   call  @create_mask_0d(%zero_idx, %one_idx) : (index, index) -> ()
 
   %red_array = arith.constant dense<5.0> : vector<f32>
-  %red_res = call  @reduce_add(%red_array) : (vector<f32>) -> (f32)  
-  vector.print %red_res : f32
-  // CHECK: 5
+  call  @reduce_add(%red_array) : (vector<f32>) -> ()
+
+  %5 = arith.constant dense<4.0> : vector<f32>
+  call  @fma_0d(%5) : (vector<f32>) -> ()
 
   return
 }


        


More information about the Mlir-commits mailing list