[Mlir-commits] [mlir] [MLIR][Vector] Add fastmath attribute to vector.contract (PR #192788)
Princeton Ferro
llvmlistbot at llvm.org
Sat Apr 18 06:48:14 PDT 2026
https://github.com/Prince781 updated https://github.com/llvm/llvm-project/pull/192788
>From a27947dcc2eeffd56a3719e5774527aa0e64993e Mon Sep 17 00:00:00 2001
From: Princeton Ferro <pferro at nvidia.com>
Date: Sat, 18 Apr 2026 06:47:51 -0700
Subject: [PATCH] fix comment: vector.reduction not vector.reduce
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 10 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 18 ++-
.../Vector/Transforms/LowerVectorContract.cpp | 48 ++++---
.../vector-contract-to-dot-transforms.mlir | 126 ++++++++++++++++++
4 files changed, 178 insertions(+), 24 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 68ef49172e662..fdde3995f6333 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -52,6 +52,7 @@ def Vector_ContractionOp :
PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>,
PredOpTrait<"third operand acc and result have same element type",
TCresVTEtIsSameAsOpBase<0, 2>>,
+ DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
@@ -59,7 +60,10 @@ def Vector_ContractionOp :
ArrayAttr:$indexing_maps,
Vector_IteratorTypeArrayAttr:$iterator_types,
DefaultValuedAttr<Vector_CombiningKindAttr,
- "CombiningKind::ADD">:$kind)>,
+ "CombiningKind::ADD">:$kind,
+ DefaultValuedAttr<
+ Arith_FastMathAttr,
+ "::mlir::arith::FastMathFlags::none">:$fastmath)>,
Results<(outs AnyType)> {
let summary = "vector contraction operation";
let description = [{
@@ -180,7 +184,9 @@ def Vector_ContractionOp :
"ArrayRef<IteratorType>":$iteratorTypes)>,
OpBuilder<(ins "Value":$lhs, "Value":$rhs, "Value":$acc,
"ArrayAttr":$indexingMaps, "ArrayAttr":$iteratorTypes,
- "CombiningKind":$kind)>
+ "CombiningKind":$kind,
+ CArg<"::mlir::arith::FastMathFlags",
+ "::mlir::arith::FastMathFlags::none">:$fastMathFlags)>
];
let extraClassDeclaration = [{
VectorType getLhsType() {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3d3e49134363f..2f48cdf2f026f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -818,13 +818,18 @@ void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
Value lhs, Value rhs, Value acc,
ArrayAttr indexingMaps,
- ArrayAttr iteratorTypes, CombiningKind kind) {
+ ArrayAttr iteratorTypes, CombiningKind kind,
+ arith::FastMathFlags fastMathFlags) {
result.addOperands({lhs, rhs, acc});
result.addTypes(acc.getType());
result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
result.addAttribute(getKindAttrName(result.name),
CombiningKindAttr::get(builder.getContext(), kind));
+ if (fastMathFlags != arith::FastMathFlags::none)
+ result.addAttribute(
+ getFastmathAttrName(result.name),
+ arith::FastMathFlagsAttr::get(builder.getContext(), fastMathFlags));
}
ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -921,8 +926,14 @@ void ContractionOp::print(OpAsmPrinter &p) {
attrs.emplace_back(getIteratorTypesAttrName(),
ArrayAttr::get(getContext(), iteratorTypeNames));
- } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
+ } else if (traitAttrsSet.count(attr.getName().strref()) > 0) {
+ // Omit fastmath when it equals the default (none) to keep output clean.
+ if (attr.getName() == getFastmathAttrName() &&
+ llvm::cast<arith::FastMathFlagsAttr>(attr.getValue()).getValue() ==
+ arith::FastMathFlags::none)
+ continue;
attrs.push_back(attr);
+ }
}
auto dictAttr = DictionaryAttr::get(getContext(), attrs);
@@ -1147,7 +1158,8 @@ Type ContractionOp::getExpectedMaskType() {
SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
return SmallVector<StringRef>{getIndexingMapsAttrName(),
- getIteratorTypesAttrName(), getKindAttrName()};
+ getIteratorTypesAttrName(), getKindAttrName(),
+ getFastmathAttrName()};
}
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 14fbdd2243676..eaf7bb8109514 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -123,7 +123,8 @@ static Value reshapeStore(Location loc, Value val, Value result,
static std::optional<Value>
createContractArithOp(Location loc, Value x, Value y, Value acc,
vector::CombiningKind kind, PatternRewriter &rewriter,
- bool isInt, Value mask = Value()) {
+ bool isInt, Value mask = Value(),
+ arith::FastMathFlagsAttr fmf = {}) {
using vector::CombiningKind;
Value mul;
@@ -150,14 +151,13 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
fma = selectPassthru(rewriter, mask, fma, acc);
return fma;
}
- mul = arith::MulFOp::create(rewriter, loc, x, y);
+ mul = arith::MulFOp::create(rewriter, loc, x, y, fmf);
}
if (!acc)
return std::optional<Value>(mul);
- return makeArithReduction(rewriter, loc, kind, mul, acc,
- /*fastmath=*/nullptr, mask);
+ return makeArithReduction(rewriter, loc, kind, mul, acc, fmf, mask);
}
/// Return the positions of the reductions in the given map.
@@ -184,19 +184,21 @@ static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
/// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
/// operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt,
- PatternRewriter &rewriter) {
+ PatternRewriter &rewriter,
+ arith::FastMathFlagsAttr fmf = {}) {
if (isInt)
return arith::AddIOp::create(rewriter, loc, x, y);
- return arith::AddFOp::create(rewriter, loc, x, y);
+ return arith::AddFOp::create(rewriter, loc, x, y, fmf);
}
/// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
/// operands `x and `y`.
static Value createMul(Location loc, Value x, Value y, bool isInt,
- PatternRewriter &rewriter) {
+ PatternRewriter &rewriter,
+ arith::FastMathFlagsAttr fmf = {}) {
if (isInt)
return arith::MulIOp::create(rewriter, loc, x, y);
- return arith::MulFOp::create(rewriter, loc, x, y);
+ return arith::MulFOp::create(rewriter, loc, x, y, fmf);
}
namespace {
@@ -705,6 +707,7 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
Value res = arith::ConstantOp::create(rewriter, loc, dstType,
rewriter.getZeroAttr(dstType));
bool isInt = isa<IntegerType>(dstType.getElementType());
+ arith::FastMathFlagsAttr fmf = op.getFastmathAttr();
llvm::SmallVector<Value> extractedCols;
extractedCols.reserve(dstColumns);
for (unsigned r = 0; r < dstRows; ++r) {
@@ -721,9 +724,10 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
}
Value extractedColRhs = extractedCols[c];
Value product =
- createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter);
- Value sum = vector::ReductionOp::create(
- rewriter, op.getLoc(), vector::CombiningKind::ADD, product);
+ createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter, fmf);
+ Value sum = vector::ReductionOp::create(rewriter, op.getLoc(),
+ vector::CombiningKind::ADD,
+ product, op.getFastmath());
SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
: SmallVector<int64_t, 2>{r, c};
@@ -731,7 +735,7 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
}
}
if (auto acc = op.getAcc())
- res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
+ res = createAdd(op.getLoc(), res, acc, isInt, rewriter, fmf);
return res;
}
@@ -845,7 +849,8 @@ struct ContractOpToElementwise
newRhs = vector::ExtractOp::create(rewriter, loc, newRhs, rhsOffsets);
std::optional<Value> result =
createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
- contractOp.getKind(), rewriter, isInt);
+ contractOp.getKind(), rewriter, isInt,
+ /*mask=*/Value(), contractOp.getFastmathAttr());
if (result)
return *result;
@@ -1053,8 +1058,9 @@ FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
iterIndex, d, rewriter);
- Operation *lowContract = vector::ContractionOp::create(
- rewriter, loc, lhs, rhs, acc, lowAffine, lowIter);
+ Operation *lowContract =
+ vector::ContractionOp::create(rewriter, loc, lhs, rhs, acc, lowAffine,
+ lowIter, op.getKind(), op.getFastmath());
lowContract = maskOperation(rewriter, lowContract, lowMask);
result = reshapeStore(loc, lowContract->getResult(0), result, resType,
resIndex, d, rewriter);
@@ -1099,13 +1105,16 @@ FailureOr<Value> ContractionOpLowering::lowerReduction(
if (rhsType.getRank() != 1)
return rewriter.notifyMatchFailure(
op, "When LHS has rank 1, expected also RHS to have rank 1");
- Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
+ arith::FastMathFlagsAttr fmf = op.getFastmathAttr();
+ Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter, fmf);
auto kind = vector::CombiningKind::ADD;
Value acc = op.getAcc();
Operation *reductionOp =
- acc ? vector::ReductionOp::create(rewriter, loc, kind, m, acc)
- : vector::ReductionOp::create(rewriter, loc, kind, m);
+ acc ? vector::ReductionOp::create(rewriter, loc, kind, m, acc,
+ op.getFastmath())
+ : vector::ReductionOp::create(rewriter, loc, kind, m,
+ op.getFastmath());
return maskOperation(rewriter, reductionOp, mask)->getResult(0);
}
// Construct new iterator types and affine map array attribute.
@@ -1130,7 +1139,8 @@ FailureOr<Value> ContractionOpLowering::lowerReduction(
iterIndex, d, rewriter);
Operation *newContract = vector::ContractionOp::create(
- rewriter, loc, lhs, rhs, result, lowAffine, lowIter);
+ rewriter, loc, lhs, rhs, result, lowAffine, lowIter, op.getKind(),
+ op.getFastmath());
result = maskOperation(rewriter, newContract, newMask)->getResult(0);
}
return result;
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
index 739796099f795..d00fe588f2b5b 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
@@ -308,6 +308,132 @@ func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1
return %res : vector<2xi32>
}
+// Verify that fastmath flags on vector.contract propagate to the lowered ops.
+// CHECK-LABEL: func @extract_contract2_fmf
+// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<3xf32> from vector<2x3xf32>
+// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[B]] fastmath<contract> : vector<3xf32>
+// CHECK: %[[T3:.*]] = vector.reduction <add>, %[[T2]] fastmath<contract> : vector<3xf32> into f32
+// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<3xf32> from vector<2x3xf32>
+// CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[B]] fastmath<contract> : vector<3xf32>
+// CHECK: %[[T8:.*]] = vector.reduction <add>, %[[T7]] fastmath<contract> : vector<3xf32> into f32
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
+// CHECK: %[[T10:.*]] = arith.addf %[[T9]], %[[C]] fastmath<contract> : vector<2xf32>
+// CHECK: return %[[T10]] : vector<2xf32>
+
+func.func @extract_contract2_fmf(%arg0: vector<2x3xf32>,
+ %arg1: vector<3xf32>,
+ %arg2: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.contract {
+ indexing_maps = #matvec_accesses,
+ iterator_types = ["parallel", "reduction"],
+ fastmath = #arith.fastmath<contract>
+ } %arg0, %arg1, %arg2 : vector<2x3xf32>, vector<3xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// Verify that fastmath flags propagate through matmat (parallel,parallel,reduction) lowering.
+// CHECK-LABEL: func @contract_to_dot_matmat_fmf
+// CHECK-SAME: %[[LHS:.*0]]: vector<2x2xf32>,
+// CHECK-SAME: %[[RHS:.*1]]: vector<2x2xf32>,
+// CHECK-SAME: %[[OUT:.*2]]: vector<2x2xf32>
+// CHECK: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK: %[[RHS_T:.*]] = vector.transpose %[[RHS]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK: %[[LHS0:.*]] = vector.extract %[[LHS]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[RHS_T0:.*]] = vector.extract %[[RHS_T]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[PROD0:.*]] = arith.mulf %[[LHS0]], %[[RHS_T0]] fastmath<contract> : vector<2xf32>
+// CHECK: %[[SUM0:.*]] = vector.reduction <add>, %[[PROD0]] fastmath<contract> : vector<2xf32> into f32
+// CHECK: %[[RES0:.*]] = vector.insert %[[SUM0]], %[[INIT]] [0, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[RHS_T1:.*]] = vector.extract %[[RHS_T]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[PROD1:.*]] = arith.mulf %[[LHS0]], %[[RHS_T1]] fastmath<contract> : vector<2xf32>
+// CHECK: %[[SUM1:.*]] = vector.reduction <add>, %[[PROD1]] fastmath<contract> : vector<2xf32> into f32
+// CHECK: %[[RES1:.*]] = vector.insert %[[SUM1]], %[[RES0]] [0, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[LHS1:.*]] = vector.extract %[[LHS]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[PROD2:.*]] = arith.mulf %[[LHS1]], %[[RHS_T0]] fastmath<contract> : vector<2xf32>
+// CHECK: %[[SUM2:.*]] = vector.reduction <add>, %[[PROD2]] fastmath<contract> : vector<2xf32> into f32
+// CHECK: %[[RES2:.*]] = vector.insert %[[SUM2]], %[[RES1]] [1, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[PROD3:.*]] = arith.mulf %[[LHS1]], %[[RHS_T1]] fastmath<contract> : vector<2xf32>
+// CHECK: %[[SUM3:.*]] = vector.reduction <add>, %[[PROD3]] fastmath<contract> : vector<2xf32> into f32
+// CHECK: %[[RES3:.*]] = vector.insert %[[SUM3]], %[[RES2]] [1, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[RES:.*]] = arith.addf %[[RES3]], %[[OUT]] fastmath<contract> : vector<2x2xf32>
+// CHECK: return %[[RES]] : vector<2x2xf32>
+
+func.func @contract_to_dot_matmat_fmf(%lhs: vector<2x2xf32>,
+ %rhs: vector<2x2xf32>,
+ %init: vector<2x2xf32>) -> vector<2x2xf32> {
+ %res = vector.contract {
+ indexing_maps = #matmat_accesses,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ fastmath = #arith.fastmath<contract>
+ } %lhs, %rhs, %init : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+ return %res : vector<2x2xf32>
+}
+
+// CHECK-LABEL: func @full_contract1_fmf
+// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
+// CHECK-SAME: %[[C:.*2]]: f32
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<3xf32> from vector<2x3xf32>
+// CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<2x3xf32>
+// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[T1]] fastmath<reassoc> : vector<3xf32>
+// CHECK: %[[T3:.*]] = vector.reduction <add>, %[[T2]], %[[C]] fastmath<reassoc> : vector<3xf32> into f32
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<3xf32> from vector<2x3xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<3xf32> from vector<2x3xf32>
+// CHECK: %[[T6:.*]] = arith.mulf %[[T4]], %[[T5]] fastmath<reassoc> : vector<3xf32>
+// CHECK: %[[T7:.*]] = vector.reduction <add>, %[[T6]], %[[T3]] fastmath<reassoc> : vector<3xf32> into f32
+// CHECK: return %[[T7]] : f32
+
+func.func @full_contract1_fmf(%arg0: vector<2x3xf32>,
+ %arg1: vector<2x3xf32>,
+ %arg2: f32) -> f32 {
+ %0 = vector.contract {
+ indexing_maps = #contraction2d_accesses,
+ iterator_types = ["reduction", "reduction"],
+ fastmath = #arith.fastmath<reassoc>
+ } %arg0, %arg1, %arg2 : vector<2x3xf32>, vector<2x3xf32> into f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func @batch_contract_fmf
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[ZERO:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+// CHECK: %[[A0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[B0:.*]] = vector.extract %[[B]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[C0:.*]] = vector.extract %[[C]][0] : f32 from vector<2xf32>
+// CHECK: %[[M0:.*]] = arith.mulf %[[A0]], %[[B0]] fastmath<reassoc> : vector<2xf32>
+// CHECK: %[[R0:.*]] = vector.reduction <add>, %[[M0]], %[[C0]] fastmath<reassoc> : vector<2xf32> into f32
+// CHECK: %[[V0:.*]] = vector.insert %[[R0]], %[[ZERO]] [0] : f32 into vector<2xf32>
+// CHECK: %[[A1:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[B1:.*]] = vector.extract %[[B]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[C1:.*]] = vector.extract %[[C]][1] : f32 from vector<2xf32>
+// CHECK: %[[M1:.*]] = arith.mulf %[[A1]], %[[B1]] fastmath<reassoc> : vector<2xf32>
+// CHECK: %[[R1:.*]] = vector.reduction <add>, %[[M1]], %[[C1]] fastmath<reassoc> : vector<2xf32> into f32
+// CHECK: %[[V1:.*]] = vector.insert %[[R1]], %[[V0]] [1] : f32 into vector<2xf32>
+// CHECK: return %[[V1]] : vector<2xf32>
+
+#batch_reduce_accesses = [
+ affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> (i)>
+]
+
+func.func @batch_contract_fmf(%arg0: vector<2x2xf32>,
+ %arg1: vector<2x2xf32>,
+ %arg2: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.contract {
+ indexing_maps = #batch_reduce_accesses,
+ iterator_types = ["parallel", "reduction"],
+ fastmath = #arith.fastmath<reassoc>
+ } %arg0, %arg1, %arg2 : vector<2x2xf32>, vector<2x2xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
More information about the Mlir-commits mailing list