[Mlir-commits] [mlir] [MLIR][Vector] Add fastmath attribute to vector.contract and propagate through lowering (PR #192788)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Apr 18 06:40:22 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Princeton Ferro (Prince781)
<details>
<summary>Changes</summary>
## Summary
`vector.contract` had no fastmath attribute, so fast-math flags set on a
contraction were silently dropped during lowering. This matters for GPU
targets like NVPTX (SM100+) that rely on the `contract` flag to fuse
`fadd(fmul(a,b), acc)` into `fma.rn`.
This patch:
- Adds a `fastmath` attribute (`arith::FastMathFlags`, default `none`) to
`vector::ContractionOp`; the attribute is omitted from the printed form when its value is `none`.
- Propagates it through all lowering paths in `LowerVectorContract`:
- **Dot strategy** (`ContractionOpToDotLowering`): forwards FMF to
`arith.mulf`, `vector.reduction<add>`, and the final accumulator `arith.addf`.
- **`lowerParallel`/`lowerReduction`** fallback: forwards `op.getKind()`
and `op.getFastmath()` to recursive `ContractionOp::create` calls.
- **`lowerReduction` base case**: passes `fmf` to `createMul` and
`ReductionOp::create`.
- **`ContractOpToElementwise`**: passes `getFastmathAttr()` to
`createContractArithOp`.
- Adds test cases (`extract_contract2_fmf`, `contract_to_dot_matmat_fmf`,
`full_contract1_fmf`, `batch_contract_fmf`) covering all lowering paths.
---
Full diff: https://github.com/llvm/llvm-project/pull/192788.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+8-2)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+15-3)
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+29-19)
- (modified) mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir (+126)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 68ef49172e662..fdde3995f6333 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -52,6 +52,7 @@ def Vector_ContractionOp :
PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>,
PredOpTrait<"third operand acc and result have same element type",
TCresVTEtIsSameAsOpBase<0, 2>>,
+ DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
@@ -59,7 +60,10 @@ def Vector_ContractionOp :
ArrayAttr:$indexing_maps,
Vector_IteratorTypeArrayAttr:$iterator_types,
DefaultValuedAttr<Vector_CombiningKindAttr,
- "CombiningKind::ADD">:$kind)>,
+ "CombiningKind::ADD">:$kind,
+ DefaultValuedAttr<
+ Arith_FastMathAttr,
+ "::mlir::arith::FastMathFlags::none">:$fastmath)>,
Results<(outs AnyType)> {
let summary = "vector contraction operation";
let description = [{
@@ -180,7 +184,9 @@ def Vector_ContractionOp :
"ArrayRef<IteratorType>":$iteratorTypes)>,
OpBuilder<(ins "Value":$lhs, "Value":$rhs, "Value":$acc,
"ArrayAttr":$indexingMaps, "ArrayAttr":$iteratorTypes,
- "CombiningKind":$kind)>
+ "CombiningKind":$kind,
+ CArg<"::mlir::arith::FastMathFlags",
+ "::mlir::arith::FastMathFlags::none">:$fastMathFlags)>
];
let extraClassDeclaration = [{
VectorType getLhsType() {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3d3e49134363f..2f48cdf2f026f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -818,13 +818,18 @@ void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
Value lhs, Value rhs, Value acc,
ArrayAttr indexingMaps,
- ArrayAttr iteratorTypes, CombiningKind kind) {
+ ArrayAttr iteratorTypes, CombiningKind kind,
+ arith::FastMathFlags fastMathFlags) {
result.addOperands({lhs, rhs, acc});
result.addTypes(acc.getType());
result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
result.addAttribute(getKindAttrName(result.name),
CombiningKindAttr::get(builder.getContext(), kind));
+ if (fastMathFlags != arith::FastMathFlags::none)
+ result.addAttribute(
+ getFastmathAttrName(result.name),
+ arith::FastMathFlagsAttr::get(builder.getContext(), fastMathFlags));
}
ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -921,8 +926,14 @@ void ContractionOp::print(OpAsmPrinter &p) {
attrs.emplace_back(getIteratorTypesAttrName(),
ArrayAttr::get(getContext(), iteratorTypeNames));
- } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
+ } else if (traitAttrsSet.count(attr.getName().strref()) > 0) {
+ // Omit fastmath when it equals the default (none) to keep output clean.
+ if (attr.getName() == getFastmathAttrName() &&
+ llvm::cast<arith::FastMathFlagsAttr>(attr.getValue()).getValue() ==
+ arith::FastMathFlags::none)
+ continue;
attrs.push_back(attr);
+ }
}
auto dictAttr = DictionaryAttr::get(getContext(), attrs);
@@ -1147,7 +1158,8 @@ Type ContractionOp::getExpectedMaskType() {
SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
return SmallVector<StringRef>{getIndexingMapsAttrName(),
- getIteratorTypesAttrName(), getKindAttrName()};
+ getIteratorTypesAttrName(), getKindAttrName(),
+ getFastmathAttrName()};
}
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 14fbdd2243676..f2eb6097bfec4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -123,7 +123,8 @@ static Value reshapeStore(Location loc, Value val, Value result,
static std::optional<Value>
createContractArithOp(Location loc, Value x, Value y, Value acc,
vector::CombiningKind kind, PatternRewriter &rewriter,
- bool isInt, Value mask = Value()) {
+ bool isInt, Value mask = Value(),
+ arith::FastMathFlagsAttr fmf = {}) {
using vector::CombiningKind;
Value mul;
@@ -150,14 +151,13 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
fma = selectPassthru(rewriter, mask, fma, acc);
return fma;
}
- mul = arith::MulFOp::create(rewriter, loc, x, y);
+ mul = arith::MulFOp::create(rewriter, loc, x, y, fmf);
}
if (!acc)
return std::optional<Value>(mul);
- return makeArithReduction(rewriter, loc, kind, mul, acc,
- /*fastmath=*/nullptr, mask);
+ return makeArithReduction(rewriter, loc, kind, mul, acc, fmf, mask);
}
/// Return the positions of the reductions in the given map.
@@ -184,19 +184,21 @@ static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
/// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
/// operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt,
- PatternRewriter &rewriter) {
+ PatternRewriter &rewriter,
+ arith::FastMathFlagsAttr fmf = {}) {
if (isInt)
return arith::AddIOp::create(rewriter, loc, x, y);
- return arith::AddFOp::create(rewriter, loc, x, y);
+ return arith::AddFOp::create(rewriter, loc, x, y, fmf);
}
/// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
/// operands `x and `y`.
static Value createMul(Location loc, Value x, Value y, bool isInt,
- PatternRewriter &rewriter) {
+ PatternRewriter &rewriter,
+ arith::FastMathFlagsAttr fmf = {}) {
if (isInt)
return arith::MulIOp::create(rewriter, loc, x, y);
- return arith::MulFOp::create(rewriter, loc, x, y);
+ return arith::MulFOp::create(rewriter, loc, x, y, fmf);
}
namespace {
@@ -253,12 +255,12 @@ class ContractionOpToOuterProductOpLowering
/// %bt = vector.transpose %b, [1, 0]
/// %aRow0 = vector.extract %a[0]
/// %btRow0 = vector.extract %bt[0]
-/// %c00 = vector.reduce %atRow0, %bRow0
+/// %c00 = vector.reduction %atRow0, %bRow0
/// %out00 = vector.insert %c00, %out[0, 0]
/// ...
/// %aRowLast = vector.extract %at[M-1]
/// %btRowLast = vector.extract %b[N-1]
-/// %cLastLast = vector.reduce %atRowLast, %bRowLast
+/// %cLastLast = vector.reduction %atRowLast, %bRowLast
/// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
/// ```
///
@@ -705,6 +707,7 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
Value res = arith::ConstantOp::create(rewriter, loc, dstType,
rewriter.getZeroAttr(dstType));
bool isInt = isa<IntegerType>(dstType.getElementType());
+ arith::FastMathFlagsAttr fmf = op.getFastmathAttr();
llvm::SmallVector<Value> extractedCols;
extractedCols.reserve(dstColumns);
for (unsigned r = 0; r < dstRows; ++r) {
@@ -721,9 +724,10 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
}
Value extractedColRhs = extractedCols[c];
Value product =
- createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter);
+ createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter, fmf);
Value sum = vector::ReductionOp::create(
- rewriter, op.getLoc(), vector::CombiningKind::ADD, product);
+ rewriter, op.getLoc(), vector::CombiningKind::ADD, product,
+ op.getFastmath());
SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
: SmallVector<int64_t, 2>{r, c};
@@ -731,7 +735,7 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
}
}
if (auto acc = op.getAcc())
- res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
+ res = createAdd(op.getLoc(), res, acc, isInt, rewriter, fmf);
return res;
}
@@ -845,7 +849,8 @@ struct ContractOpToElementwise
newRhs = vector::ExtractOp::create(rewriter, loc, newRhs, rhsOffsets);
std::optional<Value> result =
createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
- contractOp.getKind(), rewriter, isInt);
+ contractOp.getKind(), rewriter, isInt,
+ /*mask=*/Value(), contractOp.getFastmathAttr());
if (result)
return *result;
@@ -1054,7 +1059,8 @@ FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
iterIndex, d, rewriter);
Operation *lowContract = vector::ContractionOp::create(
- rewriter, loc, lhs, rhs, acc, lowAffine, lowIter);
+ rewriter, loc, lhs, rhs, acc, lowAffine, lowIter,
+ op.getKind(), op.getFastmath());
lowContract = maskOperation(rewriter, lowContract, lowMask);
result = reshapeStore(loc, lowContract->getResult(0), result, resType,
resIndex, d, rewriter);
@@ -1099,13 +1105,16 @@ FailureOr<Value> ContractionOpLowering::lowerReduction(
if (rhsType.getRank() != 1)
return rewriter.notifyMatchFailure(
op, "When LHS has rank 1, expected also RHS to have rank 1");
- Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
+ arith::FastMathFlagsAttr fmf = op.getFastmathAttr();
+ Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter, fmf);
auto kind = vector::CombiningKind::ADD;
Value acc = op.getAcc();
Operation *reductionOp =
- acc ? vector::ReductionOp::create(rewriter, loc, kind, m, acc)
- : vector::ReductionOp::create(rewriter, loc, kind, m);
+ acc ? vector::ReductionOp::create(rewriter, loc, kind, m, acc,
+ op.getFastmath())
+ : vector::ReductionOp::create(rewriter, loc, kind, m,
+ op.getFastmath());
return maskOperation(rewriter, reductionOp, mask)->getResult(0);
}
// Construct new iterator types and affine map array attribute.
@@ -1130,7 +1139,8 @@ FailureOr<Value> ContractionOpLowering::lowerReduction(
iterIndex, d, rewriter);
Operation *newContract = vector::ContractionOp::create(
- rewriter, loc, lhs, rhs, result, lowAffine, lowIter);
+ rewriter, loc, lhs, rhs, result, lowAffine, lowIter,
+ op.getKind(), op.getFastmath());
result = maskOperation(rewriter, newContract, newMask)->getResult(0);
}
return result;
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
index 739796099f795..d00fe588f2b5b 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
@@ -308,6 +308,132 @@ func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1
return %res : vector<2xi32>
}
+// Verify that fastmath flags on vector.contract propagate to the lowered ops.
+// CHECK-LABEL: func @extract_contract2_fmf
+// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<3xf32> from vector<2x3xf32>
+// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[B]] fastmath<contract> : vector<3xf32>
+// CHECK: %[[T3:.*]] = vector.reduction <add>, %[[T2]] fastmath<contract> : vector<3xf32> into f32
+// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<3xf32> from vector<2x3xf32>
+// CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[B]] fastmath<contract> : vector<3xf32>
+// CHECK: %[[T8:.*]] = vector.reduction <add>, %[[T7]] fastmath<contract> : vector<3xf32> into f32
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
+// CHECK: %[[T10:.*]] = arith.addf %[[T9]], %[[C]] fastmath<contract> : vector<2xf32>
+// CHECK: return %[[T10]] : vector<2xf32>
+
+func.func @extract_contract2_fmf(%arg0: vector<2x3xf32>,
+ %arg1: vector<3xf32>,
+ %arg2: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.contract {
+ indexing_maps = #matvec_accesses,
+ iterator_types = ["parallel", "reduction"],
+ fastmath = #arith.fastmath<contract>
+ } %arg0, %arg1, %arg2 : vector<2x3xf32>, vector<3xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// Verify that fastmath flags propagate through matmat (parallel,parallel,reduction) lowering.
+// CHECK-LABEL: func @contract_to_dot_matmat_fmf
+// CHECK-SAME: %[[LHS:.*0]]: vector<2x2xf32>,
+// CHECK-SAME: %[[RHS:.*1]]: vector<2x2xf32>,
+// CHECK-SAME: %[[OUT:.*2]]: vector<2x2xf32>
+// CHECK: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK: %[[RHS_T:.*]] = vector.transpose %[[RHS]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK: %[[LHS0:.*]] = vector.extract %[[LHS]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[RHS_T0:.*]] = vector.extract %[[RHS_T]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[PROD0:.*]] = arith.mulf %[[LHS0]], %[[RHS_T0]] fastmath<contract> : vector<2xf32>
+// CHECK: %[[SUM0:.*]] = vector.reduction <add>, %[[PROD0]] fastmath<contract> : vector<2xf32> into f32
+// CHECK: %[[RES0:.*]] = vector.insert %[[SUM0]], %[[INIT]] [0, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[RHS_T1:.*]] = vector.extract %[[RHS_T]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[PROD1:.*]] = arith.mulf %[[LHS0]], %[[RHS_T1]] fastmath<contract> : vector<2xf32>
+// CHECK: %[[SUM1:.*]] = vector.reduction <add>, %[[PROD1]] fastmath<contract> : vector<2xf32> into f32
+// CHECK: %[[RES1:.*]] = vector.insert %[[SUM1]], %[[RES0]] [0, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[LHS1:.*]] = vector.extract %[[LHS]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[PROD2:.*]] = arith.mulf %[[LHS1]], %[[RHS_T0]] fastmath<contract> : vector<2xf32>
+// CHECK: %[[SUM2:.*]] = vector.reduction <add>, %[[PROD2]] fastmath<contract> : vector<2xf32> into f32
+// CHECK: %[[RES2:.*]] = vector.insert %[[SUM2]], %[[RES1]] [1, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[PROD3:.*]] = arith.mulf %[[LHS1]], %[[RHS_T1]] fastmath<contract> : vector<2xf32>
+// CHECK: %[[SUM3:.*]] = vector.reduction <add>, %[[PROD3]] fastmath<contract> : vector<2xf32> into f32
+// CHECK: %[[RES3:.*]] = vector.insert %[[SUM3]], %[[RES2]] [1, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[RES:.*]] = arith.addf %[[RES3]], %[[OUT]] fastmath<contract> : vector<2x2xf32>
+// CHECK: return %[[RES]] : vector<2x2xf32>
+
+func.func @contract_to_dot_matmat_fmf(%lhs: vector<2x2xf32>,
+ %rhs: vector<2x2xf32>,
+ %init: vector<2x2xf32>) -> vector<2x2xf32> {
+ %res = vector.contract {
+ indexing_maps = #matmat_accesses,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ fastmath = #arith.fastmath<contract>
+ } %lhs, %rhs, %init : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+ return %res : vector<2x2xf32>
+}
+
+// CHECK-LABEL: func @full_contract1_fmf
+// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
+// CHECK-SAME: %[[C:.*2]]: f32
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<3xf32> from vector<2x3xf32>
+// CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<2x3xf32>
+// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[T1]] fastmath<reassoc> : vector<3xf32>
+// CHECK: %[[T3:.*]] = vector.reduction <add>, %[[T2]], %[[C]] fastmath<reassoc> : vector<3xf32> into f32
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<3xf32> from vector<2x3xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<3xf32> from vector<2x3xf32>
+// CHECK: %[[T6:.*]] = arith.mulf %[[T4]], %[[T5]] fastmath<reassoc> : vector<3xf32>
+// CHECK: %[[T7:.*]] = vector.reduction <add>, %[[T6]], %[[T3]] fastmath<reassoc> : vector<3xf32> into f32
+// CHECK: return %[[T7]] : f32
+
+func.func @full_contract1_fmf(%arg0: vector<2x3xf32>,
+ %arg1: vector<2x3xf32>,
+ %arg2: f32) -> f32 {
+ %0 = vector.contract {
+ indexing_maps = #contraction2d_accesses,
+ iterator_types = ["reduction", "reduction"],
+ fastmath = #arith.fastmath<reassoc>
+ } %arg0, %arg1, %arg2 : vector<2x3xf32>, vector<2x3xf32> into f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func @batch_contract_fmf
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[ZERO:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+// CHECK: %[[A0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[B0:.*]] = vector.extract %[[B]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[C0:.*]] = vector.extract %[[C]][0] : f32 from vector<2xf32>
+// CHECK: %[[M0:.*]] = arith.mulf %[[A0]], %[[B0]] fastmath<reassoc> : vector<2xf32>
+// CHECK: %[[R0:.*]] = vector.reduction <add>, %[[M0]], %[[C0]] fastmath<reassoc> : vector<2xf32> into f32
+// CHECK: %[[V0:.*]] = vector.insert %[[R0]], %[[ZERO]] [0] : f32 into vector<2xf32>
+// CHECK: %[[A1:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[B1:.*]] = vector.extract %[[B]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[C1:.*]] = vector.extract %[[C]][1] : f32 from vector<2xf32>
+// CHECK: %[[M1:.*]] = arith.mulf %[[A1]], %[[B1]] fastmath<reassoc> : vector<2xf32>
+// CHECK: %[[R1:.*]] = vector.reduction <add>, %[[M1]], %[[C1]] fastmath<reassoc> : vector<2xf32> into f32
+// CHECK: %[[V1:.*]] = vector.insert %[[R1]], %[[V0]] [1] : f32 into vector<2xf32>
+// CHECK: return %[[V1]] : vector<2xf32>
+
+#batch_reduce_accesses = [
+ affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> (i)>
+]
+
+func.func @batch_contract_fmf(%arg0: vector<2x2xf32>,
+ %arg1: vector<2x2xf32>,
+ %arg2: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.contract {
+ indexing_maps = #batch_reduce_accesses,
+ iterator_types = ["parallel", "reduction"],
+ fastmath = #arith.fastmath<reassoc>
+ } %arg0, %arg1, %arg2 : vector<2x2xf32>, vector<2x2xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
``````````
</details>
https://github.com/llvm/llvm-project/pull/192788
More information about the Mlir-commits
mailing list