[Mlir-commits] [mlir] [mlir][vector] Lower vector.fma to llvm.intr.fma (PR #143089)

Jianjian Guan llvmlistbot at llvm.org
Fri Jun 6 01:07:22 PDT 2025


https://github.com/jacquesguan created https://github.com/llvm/llvm-project/pull/143089

Currently, vector.fma is lowered to llvm.intr.fmuladd; this patch lowers it to llvm.intr.fma instead.

From eed9292a8aecaff679b0bdbcac8ae6bdc56ce015 Mon Sep 17 00:00:00 2001
From: Jianjian GUAN <jacquesguan at me.com>
Date: Fri, 6 Jun 2025 16:02:50 +0800
Subject: [PATCH] [mlir][vector] Lower vector.fma to llvm.intr.fma

Currently, vector.fma is lowered to llvm.intr.fmuladd; this patch lowers it to
llvm.intr.fma instead.
---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      |  6 ++--
 .../VectorToLLVM/vector-to-llvm.mlir          | 30 +++++++++----------
 mlir/test/Dialect/LLVM/transform-e2e.mlir     |  2 +-
 3 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index f725993635672..26a2de43e2836 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1208,7 +1208,7 @@ class VectorExtractOpConversion
 };
 
 /// Conversion pattern that turns a vector.fma on a 1-D vector
-/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
+/// into an llvm.intr.fma. This is a trivial 1-1 conversion.
 /// This does not match vectors of n >= 2 rank.
 ///
 /// Example:
@@ -1217,7 +1217,7 @@ class VectorExtractOpConversion
 /// ```
 /// is converted to:
 /// ```
-///  llvm.intr.fmuladd %va, %va, %va:
+///  llvm.intr.fma %va, %va, %va:
 ///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
 ///    -> !llvm."<8 x f32>">
 /// ```
@@ -1232,7 +1232,7 @@ class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
     if (vType.getRank() > 1)
       return failure();
 
-    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
+    rewriter.replaceOpWithNewOp<LLVM::FMAOp>(
         fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
     return success();
   }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 64e51f5554628..9d48a3a93e1bb 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -660,14 +660,14 @@ func.func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: v
 // CHECK:       %[[T6Insert:.*]] = llvm.insertelement %[[T5]]
 // CHECK:       %[[T6:.*]] = llvm.shufflevector %[[T6Insert]]
 // CHECK:       %[[T8:.*]] = llvm.extractvalue %[[T7]][0] : !llvm.array<2 x vector<3xf32>>
-// CHECK:       %[[T9:.*]] = llvm.intr.fmuladd(%[[T6]], %[[B]], %[[T8]]) : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> vector<3xf32>
+// CHECK:       %[[T9:.*]] = llvm.intr.fma(%[[T6]], %[[B]], %[[T8]]) : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> vector<3xf32>
 // CHECK:       %[[T11:.*]] = llvm.insertvalue %[[T9]], %[[T10]][0] : !llvm.array<2 x vector<3xf32>>
 // CHECK:       %[[T12:.*]] = llvm.mlir.constant(1 : i64) : i64
 // CHECK:       %[[T13:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T12]] : i64] : vector<2xf32>
 // CHECK:       %[[T14Insert:.*]] = llvm.insertelement %[[T13]]
 // CHECK:       %[[T14:.*]] = llvm.shufflevector %[[T14Insert]]
 // CHECK:       %[[T16:.*]] = llvm.extractvalue %[[T7]][1] : !llvm.array<2 x vector<3xf32>>
-// CHECK:       %[[T17:.*]] = llvm.intr.fmuladd(%[[T14]], %[[B]], %[[T16]]) : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> vector<3xf32>
+// CHECK:       %[[T17:.*]] = llvm.intr.fma(%[[T14]], %[[B]], %[[T16]]) : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> vector<3xf32>
 // CHECK:       %[[T18:.*]] = llvm.insertvalue %[[T17]], %[[T11]][1] : !llvm.array<2 x vector<3xf32>>
 // CHECK:       %[[T19:.*]] = builtin.unrealized_conversion_cast %[[T18]] : !llvm.array<2 x vector<3xf32>> to vector<2x3xf32>
 // CHECK:       return %[[T19]] : vector<2x3xf32>
@@ -690,14 +690,14 @@ func.func @outerproduct_add_scalable(%arg0: vector<2xf32>, %arg1: vector<[3]xf32
 // CHECK:       %[[T6Insert:.*]] = llvm.insertelement %[[T5]]
 // CHECK:       %[[T6:.*]] = llvm.shufflevector %[[T6Insert]]
 // CHECK:       %[[T8:.*]] = llvm.extractvalue %[[T7]][0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK:       %[[T9:.*]] = llvm.intr.fmuladd(%[[T6]], %[[B]], %[[T8]]) : (vector<[3]xf32>, vector<[3]xf32>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK:       %[[T9:.*]] = llvm.intr.fma(%[[T6]], %[[B]], %[[T8]]) : (vector<[3]xf32>, vector<[3]xf32>, vector<[3]xf32>) -> vector<[3]xf32>
 // CHECK:       %[[T11:.*]] = llvm.insertvalue %[[T9]], %[[T10]][0] : !llvm.array<2 x vector<[3]xf32>>
 // CHECK:       %[[T12:.*]] = llvm.mlir.constant(1 : i64) : i64
 // CHECK:       %[[T13:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T12]] : i64] : vector<2xf32>
 // CHECK:       %[[T14Insert:.*]] = llvm.insertelement %[[T13]]
 // CHECK:       %[[T14:.*]] = llvm.shufflevector %[[T14Insert]]
 // CHECK:       %[[T16:.*]] = llvm.extractvalue %[[T7]][1] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK:       %[[T17:.*]] = llvm.intr.fmuladd(%[[T14]], %[[B]], %[[T16]]) : (vector<[3]xf32>, vector<[3]xf32>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK:       %[[T17:.*]] = llvm.intr.fma(%[[T14]], %[[B]], %[[T16]]) : (vector<[3]xf32>, vector<[3]xf32>, vector<[3]xf32>) -> vector<[3]xf32>
 // CHECK:       %[[T18:.*]] = llvm.insertvalue %[[T17]], %[[T11]][1] : !llvm.array<2 x vector<[3]xf32>>
 // CHECK:       %[[T19:.*]] = builtin.unrealized_conversion_cast %[[T18]] : !llvm.array<2 x vector<[3]xf32>> to vector<2x[3]xf32>
 // CHECK:       return %[[T19]] : vector<2x[3]xf32>
@@ -715,7 +715,7 @@ func.func @masked_float_add_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v
 
 // CHECK-LABEL:   func.func @masked_float_add_outerprod(
 // CHECK-SAME:                                          %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
-// CHECK:           %[[VAL_8:.*]] = llvm.intr.fmuladd(%[[VAL_0]], %{{.*}}, %[[VAL_2]])  : (vector<2xf32>, vector<2xf32>, vector<2xf32>) -> vector<2xf32>
+// CHECK:           %[[VAL_8:.*]] = llvm.intr.fma(%[[VAL_0]], %{{.*}}, %[[VAL_2]])  : (vector<2xf32>, vector<2xf32>, vector<2xf32>) -> vector<2xf32>
 // CHECK:           %[[VAL_9:.*]] = arith.select %[[VAL_3]], %[[VAL_8]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
 
 // -----
@@ -727,7 +727,7 @@ func.func @masked_float_add_outerprod_scalable(%arg0: vector<[2]xf32>, %arg1: f3
 
 // CHECK-LABEL:   func.func @masked_float_add_outerprod_scalable(
 // CHECK-SAME:                                                   %[[VAL_0:.*]]: vector<[2]xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<[2]xf32>, %[[VAL_3:.*]]: vector<[2]xi1>) -> vector<[2]xf32> {
-// CHECK:           %[[VAL_8:.*]] = llvm.intr.fmuladd(%[[VAL_0]], %{{.*}}, %[[VAL_2]])  : (vector<[2]xf32>, vector<[2]xf32>, vector<[2]xf32>) -> vector<[2]xf32>
+// CHECK:           %[[VAL_8:.*]] = llvm.intr.fma(%[[VAL_0]], %{{.*}}, %[[VAL_2]])  : (vector<[2]xf32>, vector<[2]xf32>, vector<[2]xf32>) -> vector<[2]xf32>
 // CHECK:           %[[VAL_9:.*]] = arith.select %[[VAL_3]], %[[VAL_8]], %[[VAL_2]] : vector<[2]xi1>, vector<[2]xf32>
 
 // -----
@@ -1363,29 +1363,29 @@ func.func @fma(%vec_1d: vector<8xf32>, %vec_2d: vector<2x4xf32>, %vec_3d: vector
   //  CHECK-SAME: %[[VEC_2D:.*]]: vector<2x4xf32>
   //  CHECK-SAME: %[[VEC_3D:.*]]: vector<1x1x1xf32>
   //       CHECK: %[[VEC_2D_CAST:.*]] = builtin.unrealized_conversion_cast %[[VEC_2D]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>>
-  //       CHECK: llvm.intr.fmuladd
+  //       CHECK: llvm.intr.fma
   //  CHECK-SAME:   (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
   %0 = vector.fma %vec_1d, %vec_1d, %vec_1d : vector<8xf32>
 
   //       CHECK: %[[VEC_2D_00:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][0] : !llvm.array<2 x vector<4xf32>>
   //       CHECK: %[[VEC_2D_01:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][0] : !llvm.array<2 x vector<4xf32>>
   //       CHECK: %[[VEC_2D_02:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][0] : !llvm.array<2 x vector<4xf32>>
-  //       CHECK: %[[VEC_2D_ADD_1:.*]] = llvm.intr.fmuladd(%[[VEC_2D_00]], %[[VEC_2D_01]], %[[VEC_2D_02]]) :
+  //       CHECK: %[[VEC_2D_ADD_1:.*]] = llvm.intr.fma(%[[VEC_2D_00]], %[[VEC_2D_01]], %[[VEC_2D_02]]) :
   //  CHECK-SAME: (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> vector<4xf32>
   //       CHECK: llvm.insertvalue %[[VEC_2D_ADD_1]], {{.*}}[0] : !llvm.array<2 x vector<4xf32>>
   //       CHECK: %[[VEC_2D_10:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][1] : !llvm.array<2 x vector<4xf32>>
   //       CHECK: %[[VEC_2D_11:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][1] : !llvm.array<2 x vector<4xf32>>
   //       CHECK: %[[VEC_2D_12:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][1] : !llvm.array<2 x vector<4xf32>>
-  //       CHECK: %[[VEC_2D_ADD_2:.*]] = llvm.intr.fmuladd(%[[VEC_2D_10]], %[[VEC_2D_11]], %[[VEC_2D_12]]) :
+  //       CHECK: %[[VEC_2D_ADD_2:.*]] = llvm.intr.fma(%[[VEC_2D_10]], %[[VEC_2D_11]], %[[VEC_2D_12]]) :
   //  CHECK-SAME: (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> vector<4xf32>
   //       CHECK: llvm.insertvalue %[[VEC_2D_ADD_2]], {{.*}}[1] : !llvm.array<2 x vector<4xf32>>
   %1 = vector.fma %vec_2d, %vec_2d, %vec_2d : vector<2x4xf32>
 
-  //       CHECK: %[[C0:.*]] = llvm.intr.fmuladd
+  //       CHECK: %[[C0:.*]] = llvm.intr.fma
   //  CHECK-SAME:   (vector<1xf32>, vector<1xf32>, vector<1xf32>) -> vector<1xf32>
   %2 = vector.fma %vec_3d, %vec_3d, %vec_3d : vector<1x1x1xf32>
 
-  //       CHECK: %[[D0:.*]] = llvm.intr.fmuladd
+  //       CHECK: %[[D0:.*]] = llvm.intr.fma
   //  CHECK-SAME:   (vector<1xf32>, vector<1xf32>, vector<1xf32>) -> vector<1xf32>
   %3 = vector.fma %vec_0d, %vec_0d, %vec_0d : vector<f32>
 
@@ -1400,25 +1400,25 @@ func.func @fma_scalable(%vec_1d: vector<[8]xf32>, %vec_2d: vector<2x[4]xf32>, %v
   //  CHECK-SAME: %[[VEC_2D:.*]]: vector<2x[4]xf32>
   //  CHECK-SAME: %[[VEC_3D:.*]]: vector<1x1x[1]xf32>
   //       CHECK: %[[VEC_2D_CAST:.*]] = builtin.unrealized_conversion_cast %[[VEC_2D]] : vector<2x[4]xf32> to !llvm.array<2 x vector<[4]xf32>>
-  //       CHECK: llvm.intr.fmuladd
+  //       CHECK: llvm.intr.fma
   //  CHECK-SAME:   (vector<[8]xf32>, vector<[8]xf32>, vector<[8]xf32>) -> vector<[8]xf32>
   %0 = vector.fma %vec_1d, %vec_1d, %vec_1d : vector<[8]xf32>
 
   //       CHECK: %[[VEC_2D_00:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][0] : !llvm.array<2 x vector<[4]xf32>>
   //       CHECK: %[[VEC_2D_01:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][0] : !llvm.array<2 x vector<[4]xf32>>
   //       CHECK: %[[VEC_2D_02:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][0] : !llvm.array<2 x vector<[4]xf32>>
-  //       CHECK: %[[VEC_2D_ADD_1:.*]] = llvm.intr.fmuladd(%[[VEC_2D_00]], %[[VEC_2D_01]], %[[VEC_2D_02]]) :
+  //       CHECK: %[[VEC_2D_ADD_1:.*]] = llvm.intr.fma(%[[VEC_2D_00]], %[[VEC_2D_01]], %[[VEC_2D_02]]) :
   //  CHECK-SAME: (vector<[4]xf32>, vector<[4]xf32>, vector<[4]xf32>) -> vector<[4]xf32>
   //       CHECK: llvm.insertvalue %[[VEC_2D_ADD_1]], {{.*}}[0] : !llvm.array<2 x vector<[4]xf32>>
   //       CHECK: %[[VEC_2D_10:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][1] : !llvm.array<2 x vector<[4]xf32>>
   //       CHECK: %[[VEC_2D_11:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][1] : !llvm.array<2 x vector<[4]xf32>>
   //       CHECK: %[[VEC_2D_12:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][1] : !llvm.array<2 x vector<[4]xf32>>
-  //       CHECK: %[[VEC_2D_ADD_2:.*]] = llvm.intr.fmuladd(%[[VEC_2D_10]], %[[VEC_2D_11]], %[[VEC_2D_12]]) :
+  //       CHECK: %[[VEC_2D_ADD_2:.*]] = llvm.intr.fma(%[[VEC_2D_10]], %[[VEC_2D_11]], %[[VEC_2D_12]]) :
   //  CHECK-SAME: (vector<[4]xf32>, vector<[4]xf32>, vector<[4]xf32>) -> vector<[4]xf32>
   //       CHECK: llvm.insertvalue %[[VEC_2D_ADD_2]], {{.*}}[1] : !llvm.array<2 x vector<[4]xf32>>
   %1 = vector.fma %vec_2d, %vec_2d, %vec_2d : vector<2x[4]xf32>
 
-  //       CHECK: %[[C0:.*]] = llvm.intr.fmuladd
+  //       CHECK: %[[C0:.*]] = llvm.intr.fma
   //  CHECK-SAME:   (vector<[1]xf32>, vector<[1]xf32>, vector<[1]xf32>) -> vector<[1]xf32>
   %2 = vector.fma %vec_3d, %vec_3d, %vec_3d : vector<1x1x[1]xf32>
 
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index c00b47fb936e9..fc87d66e9d2c8 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -5,7 +5,7 @@ func.func @matmul_tensors(
   %arg0: tensor<2x4xf32>, %arg1: tensor<4x6xf32>, %arg2: tensor<2x6xf32>)
     -> tensor<2x6xf32> {
 // CHECK-NOT: linalg
-// CHECK: llvm.intr.fmuladd{{.*}}
+// CHECK: llvm.intr.fma{{.*}}
   %0 = linalg.matmul  ins(%arg0, %arg1: tensor<2x4xf32>, tensor<4x6xf32>)
                      outs(%arg2: tensor<2x6xf32>)
     -> tensor<2x6xf32>



More information about the Mlir-commits mailing list