[Mlir-commits] [mlir] [mlir][vector] Lower vector.fma to llvm.intr.fma (PR #143089)
Jianjian Guan
llvmlistbot at llvm.org
Fri Jun 6 01:07:22 PDT 2025
https://github.com/jacquesguan created https://github.com/llvm/llvm-project/pull/143089
Currently, vector.fma lowers to llvm.intr.fmuladd; this patch lowers it to llvm.intr.fma instead.
>From eed9292a8aecaff679b0bdbcac8ae6bdc56ce015 Mon Sep 17 00:00:00 2001
From: Jianjian GUAN <jacquesguan at me.com>
Date: Fri, 6 Jun 2025 16:02:50 +0800
Subject: [PATCH] [mlir][vector] Lower vector.fma to llvm.intr.fma
Currently, vector.fma lowers to llvm.intr.fmuladd; this patch lowers it to llvm.intr.fma instead.
---
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 6 ++--
.../VectorToLLVM/vector-to-llvm.mlir | 30 +++++++++----------
mlir/test/Dialect/LLVM/transform-e2e.mlir | 2 +-
3 files changed, 19 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index f725993635672..26a2de43e2836 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1208,7 +1208,7 @@ class VectorExtractOpConversion
};
/// Conversion pattern that turns a vector.fma on a 1-D vector
-/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
+/// into an llvm.intr.fma. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
@@ -1217,7 +1217,7 @@ class VectorExtractOpConversion
/// ```
/// is converted to:
/// ```
-/// llvm.intr.fmuladd %va, %va, %va:
+/// llvm.intr.fma %va, %va, %va:
/// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
/// -> !llvm."<8 x f32>">
/// ```
@@ -1232,7 +1232,7 @@ class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
if (vType.getRank() > 1)
return failure();
- rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
+ rewriter.replaceOpWithNewOp<LLVM::FMAOp>(
fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
return success();
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 64e51f5554628..9d48a3a93e1bb 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -660,14 +660,14 @@ func.func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: v
// CHECK: %[[T6Insert:.*]] = llvm.insertelement %[[T5]]
// CHECK: %[[T6:.*]] = llvm.shufflevector %[[T6Insert]]
// CHECK: %[[T8:.*]] = llvm.extractvalue %[[T7]][0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[T9:.*]] = llvm.intr.fmuladd(%[[T6]], %[[B]], %[[T8]]) : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %[[T9:.*]] = llvm.intr.fma(%[[T6]], %[[B]], %[[T8]]) : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> vector<3xf32>
// CHECK: %[[T11:.*]] = llvm.insertvalue %[[T9]], %[[T10]][0] : !llvm.array<2 x vector<3xf32>>
// CHECK: %[[T12:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[T13:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T12]] : i64] : vector<2xf32>
// CHECK: %[[T14Insert:.*]] = llvm.insertelement %[[T13]]
// CHECK: %[[T14:.*]] = llvm.shufflevector %[[T14Insert]]
// CHECK: %[[T16:.*]] = llvm.extractvalue %[[T7]][1] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[T17:.*]] = llvm.intr.fmuladd(%[[T14]], %[[B]], %[[T16]]) : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %[[T17:.*]] = llvm.intr.fma(%[[T14]], %[[B]], %[[T16]]) : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> vector<3xf32>
// CHECK: %[[T18:.*]] = llvm.insertvalue %[[T17]], %[[T11]][1] : !llvm.array<2 x vector<3xf32>>
// CHECK: %[[T19:.*]] = builtin.unrealized_conversion_cast %[[T18]] : !llvm.array<2 x vector<3xf32>> to vector<2x3xf32>
// CHECK: return %[[T19]] : vector<2x3xf32>
@@ -690,14 +690,14 @@ func.func @outerproduct_add_scalable(%arg0: vector<2xf32>, %arg1: vector<[3]xf32
// CHECK: %[[T6Insert:.*]] = llvm.insertelement %[[T5]]
// CHECK: %[[T6:.*]] = llvm.shufflevector %[[T6Insert]]
// CHECK: %[[T8:.*]] = llvm.extractvalue %[[T7]][0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[T9:.*]] = llvm.intr.fmuladd(%[[T6]], %[[B]], %[[T8]]) : (vector<[3]xf32>, vector<[3]xf32>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK: %[[T9:.*]] = llvm.intr.fma(%[[T6]], %[[B]], %[[T8]]) : (vector<[3]xf32>, vector<[3]xf32>, vector<[3]xf32>) -> vector<[3]xf32>
// CHECK: %[[T11:.*]] = llvm.insertvalue %[[T9]], %[[T10]][0] : !llvm.array<2 x vector<[3]xf32>>
// CHECK: %[[T12:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[T13:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T12]] : i64] : vector<2xf32>
// CHECK: %[[T14Insert:.*]] = llvm.insertelement %[[T13]]
// CHECK: %[[T14:.*]] = llvm.shufflevector %[[T14Insert]]
// CHECK: %[[T16:.*]] = llvm.extractvalue %[[T7]][1] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[T17:.*]] = llvm.intr.fmuladd(%[[T14]], %[[B]], %[[T16]]) : (vector<[3]xf32>, vector<[3]xf32>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK: %[[T17:.*]] = llvm.intr.fma(%[[T14]], %[[B]], %[[T16]]) : (vector<[3]xf32>, vector<[3]xf32>, vector<[3]xf32>) -> vector<[3]xf32>
// CHECK: %[[T18:.*]] = llvm.insertvalue %[[T17]], %[[T11]][1] : !llvm.array<2 x vector<[3]xf32>>
// CHECK: %[[T19:.*]] = builtin.unrealized_conversion_cast %[[T18]] : !llvm.array<2 x vector<[3]xf32>> to vector<2x[3]xf32>
// CHECK: return %[[T19]] : vector<2x[3]xf32>
@@ -715,7 +715,7 @@ func.func @masked_float_add_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v
// CHECK-LABEL: func.func @masked_float_add_outerprod(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
-// CHECK: %[[VAL_8:.*]] = llvm.intr.fmuladd(%[[VAL_0]], %{{.*}}, %[[VAL_2]]) : (vector<2xf32>, vector<2xf32>, vector<2xf32>) -> vector<2xf32>
+// CHECK: %[[VAL_8:.*]] = llvm.intr.fma(%[[VAL_0]], %{{.*}}, %[[VAL_2]]) : (vector<2xf32>, vector<2xf32>, vector<2xf32>) -> vector<2xf32>
// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_3]], %[[VAL_8]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
// -----
@@ -727,7 +727,7 @@ func.func @masked_float_add_outerprod_scalable(%arg0: vector<[2]xf32>, %arg1: f3
// CHECK-LABEL: func.func @masked_float_add_outerprod_scalable(
// CHECK-SAME: %[[VAL_0:.*]]: vector<[2]xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<[2]xf32>, %[[VAL_3:.*]]: vector<[2]xi1>) -> vector<[2]xf32> {
-// CHECK: %[[VAL_8:.*]] = llvm.intr.fmuladd(%[[VAL_0]], %{{.*}}, %[[VAL_2]]) : (vector<[2]xf32>, vector<[2]xf32>, vector<[2]xf32>) -> vector<[2]xf32>
+// CHECK: %[[VAL_8:.*]] = llvm.intr.fma(%[[VAL_0]], %{{.*}}, %[[VAL_2]]) : (vector<[2]xf32>, vector<[2]xf32>, vector<[2]xf32>) -> vector<[2]xf32>
// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_3]], %[[VAL_8]], %[[VAL_2]] : vector<[2]xi1>, vector<[2]xf32>
// -----
@@ -1363,29 +1363,29 @@ func.func @fma(%vec_1d: vector<8xf32>, %vec_2d: vector<2x4xf32>, %vec_3d: vector
// CHECK-SAME: %[[VEC_2D:.*]]: vector<2x4xf32>
// CHECK-SAME: %[[VEC_3D:.*]]: vector<1x1x1xf32>
// CHECK: %[[VEC_2D_CAST:.*]] = builtin.unrealized_conversion_cast %[[VEC_2D]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>>
- // CHECK: llvm.intr.fmuladd
+ // CHECK: llvm.intr.fma
// CHECK-SAME: (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
%0 = vector.fma %vec_1d, %vec_1d, %vec_1d : vector<8xf32>
// CHECK: %[[VEC_2D_00:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][0] : !llvm.array<2 x vector<4xf32>>
// CHECK: %[[VEC_2D_01:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][0] : !llvm.array<2 x vector<4xf32>>
// CHECK: %[[VEC_2D_02:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][0] : !llvm.array<2 x vector<4xf32>>
- // CHECK: %[[VEC_2D_ADD_1:.*]] = llvm.intr.fmuladd(%[[VEC_2D_00]], %[[VEC_2D_01]], %[[VEC_2D_02]]) :
+ // CHECK: %[[VEC_2D_ADD_1:.*]] = llvm.intr.fma(%[[VEC_2D_00]], %[[VEC_2D_01]], %[[VEC_2D_02]]) :
// CHECK-SAME: (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> vector<4xf32>
// CHECK: llvm.insertvalue %[[VEC_2D_ADD_1]], {{.*}}[0] : !llvm.array<2 x vector<4xf32>>
// CHECK: %[[VEC_2D_10:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][1] : !llvm.array<2 x vector<4xf32>>
// CHECK: %[[VEC_2D_11:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][1] : !llvm.array<2 x vector<4xf32>>
// CHECK: %[[VEC_2D_12:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][1] : !llvm.array<2 x vector<4xf32>>
- // CHECK: %[[VEC_2D_ADD_2:.*]] = llvm.intr.fmuladd(%[[VEC_2D_10]], %[[VEC_2D_11]], %[[VEC_2D_12]]) :
+ // CHECK: %[[VEC_2D_ADD_2:.*]] = llvm.intr.fma(%[[VEC_2D_10]], %[[VEC_2D_11]], %[[VEC_2D_12]]) :
// CHECK-SAME: (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> vector<4xf32>
// CHECK: llvm.insertvalue %[[VEC_2D_ADD_2]], {{.*}}[1] : !llvm.array<2 x vector<4xf32>>
%1 = vector.fma %vec_2d, %vec_2d, %vec_2d : vector<2x4xf32>
- // CHECK: %[[C0:.*]] = llvm.intr.fmuladd
+ // CHECK: %[[C0:.*]] = llvm.intr.fma
// CHECK-SAME: (vector<1xf32>, vector<1xf32>, vector<1xf32>) -> vector<1xf32>
%2 = vector.fma %vec_3d, %vec_3d, %vec_3d : vector<1x1x1xf32>
- // CHECK: %[[D0:.*]] = llvm.intr.fmuladd
+ // CHECK: %[[D0:.*]] = llvm.intr.fma
// CHECK-SAME: (vector<1xf32>, vector<1xf32>, vector<1xf32>) -> vector<1xf32>
%3 = vector.fma %vec_0d, %vec_0d, %vec_0d : vector<f32>
@@ -1400,25 +1400,25 @@ func.func @fma_scalable(%vec_1d: vector<[8]xf32>, %vec_2d: vector<2x[4]xf32>, %v
// CHECK-SAME: %[[VEC_2D:.*]]: vector<2x[4]xf32>
// CHECK-SAME: %[[VEC_3D:.*]]: vector<1x1x[1]xf32>
// CHECK: %[[VEC_2D_CAST:.*]] = builtin.unrealized_conversion_cast %[[VEC_2D]] : vector<2x[4]xf32> to !llvm.array<2 x vector<[4]xf32>>
- // CHECK: llvm.intr.fmuladd
+ // CHECK: llvm.intr.fma
// CHECK-SAME: (vector<[8]xf32>, vector<[8]xf32>, vector<[8]xf32>) -> vector<[8]xf32>
%0 = vector.fma %vec_1d, %vec_1d, %vec_1d : vector<[8]xf32>
// CHECK: %[[VEC_2D_00:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][0] : !llvm.array<2 x vector<[4]xf32>>
// CHECK: %[[VEC_2D_01:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][0] : !llvm.array<2 x vector<[4]xf32>>
// CHECK: %[[VEC_2D_02:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][0] : !llvm.array<2 x vector<[4]xf32>>
- // CHECK: %[[VEC_2D_ADD_1:.*]] = llvm.intr.fmuladd(%[[VEC_2D_00]], %[[VEC_2D_01]], %[[VEC_2D_02]]) :
+ // CHECK: %[[VEC_2D_ADD_1:.*]] = llvm.intr.fma(%[[VEC_2D_00]], %[[VEC_2D_01]], %[[VEC_2D_02]]) :
// CHECK-SAME: (vector<[4]xf32>, vector<[4]xf32>, vector<[4]xf32>) -> vector<[4]xf32>
// CHECK: llvm.insertvalue %[[VEC_2D_ADD_1]], {{.*}}[0] : !llvm.array<2 x vector<[4]xf32>>
// CHECK: %[[VEC_2D_10:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][1] : !llvm.array<2 x vector<[4]xf32>>
// CHECK: %[[VEC_2D_11:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][1] : !llvm.array<2 x vector<[4]xf32>>
// CHECK: %[[VEC_2D_12:.*]] = llvm.extractvalue %[[VEC_2D_CAST]][1] : !llvm.array<2 x vector<[4]xf32>>
- // CHECK: %[[VEC_2D_ADD_2:.*]] = llvm.intr.fmuladd(%[[VEC_2D_10]], %[[VEC_2D_11]], %[[VEC_2D_12]]) :
+ // CHECK: %[[VEC_2D_ADD_2:.*]] = llvm.intr.fma(%[[VEC_2D_10]], %[[VEC_2D_11]], %[[VEC_2D_12]]) :
// CHECK-SAME: (vector<[4]xf32>, vector<[4]xf32>, vector<[4]xf32>) -> vector<[4]xf32>
// CHECK: llvm.insertvalue %[[VEC_2D_ADD_2]], {{.*}}[1] : !llvm.array<2 x vector<[4]xf32>>
%1 = vector.fma %vec_2d, %vec_2d, %vec_2d : vector<2x[4]xf32>
- // CHECK: %[[C0:.*]] = llvm.intr.fmuladd
+ // CHECK: %[[C0:.*]] = llvm.intr.fma
// CHECK-SAME: (vector<[1]xf32>, vector<[1]xf32>, vector<[1]xf32>) -> vector<[1]xf32>
%2 = vector.fma %vec_3d, %vec_3d, %vec_3d : vector<1x1x[1]xf32>
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index c00b47fb936e9..fc87d66e9d2c8 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -5,7 +5,7 @@ func.func @matmul_tensors(
%arg0: tensor<2x4xf32>, %arg1: tensor<4x6xf32>, %arg2: tensor<2x6xf32>)
-> tensor<2x6xf32> {
// CHECK-NOT: linalg
-// CHECK: llvm.intr.fmuladd{{.*}}
+// CHECK: llvm.intr.fma{{.*}}
%0 = linalg.matmul ins(%arg0, %arg1: tensor<2x4xf32>, tensor<4x6xf32>)
outs(%arg2: tensor<2x6xf32>)
-> tensor<2x6xf32>
More information about the Mlir-commits
mailing list