[Mlir-commits] [mlir] [MLIR][Vector] Add fastmath attribute to vector.contract (PR #192788)

Princeton Ferro llvmlistbot at llvm.org
Sat Apr 18 06:48:14 PDT 2026


https://github.com/Prince781 updated https://github.com/llvm/llvm-project/pull/192788

>From a27947dcc2eeffd56a3719e5774527aa0e64993e Mon Sep 17 00:00:00 2001
From: Princeton Ferro <pferro at nvidia.com>
Date: Sat, 18 Apr 2026 06:47:51 -0700
Subject: [PATCH] fix comment: vector.reduction not vector.reduce

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  10 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  18 ++-
 .../Vector/Transforms/LowerVectorContract.cpp |  48 ++++---
 .../vector-contract-to-dot-transforms.mlir    | 126 ++++++++++++++++++
 4 files changed, 178 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 68ef49172e662..fdde3995f6333 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -52,6 +52,7 @@ def Vector_ContractionOp :
       PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>,
       PredOpTrait<"third operand acc and result have same element type",
                   TCresVTEtIsSameAsOpBase<0, 2>>,
+      DeclareOpInterfaceMethods<ArithFastMathInterface>,
       DeclareOpInterfaceMethods<MaskableOpInterface>,
       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
     ]>,
@@ -59,7 +60,10 @@ def Vector_ContractionOp :
                ArrayAttr:$indexing_maps,
                Vector_IteratorTypeArrayAttr:$iterator_types,
                DefaultValuedAttr<Vector_CombiningKindAttr,
-                                 "CombiningKind::ADD">:$kind)>,
+                                 "CombiningKind::ADD">:$kind,
+               DefaultValuedAttr<
+                 Arith_FastMathAttr,
+                 "::mlir::arith::FastMathFlags::none">:$fastmath)>,
     Results<(outs AnyType)> {
   let summary = "vector contraction operation";
   let description = [{
@@ -180,7 +184,9 @@ def Vector_ContractionOp :
       "ArrayRef<IteratorType>":$iteratorTypes)>,
     OpBuilder<(ins "Value":$lhs, "Value":$rhs, "Value":$acc,
       "ArrayAttr":$indexingMaps, "ArrayAttr":$iteratorTypes,
-      "CombiningKind":$kind)>
+      "CombiningKind":$kind,
+      CArg<"::mlir::arith::FastMathFlags",
+           "::mlir::arith::FastMathFlags::none">:$fastMathFlags)>
   ];
   let extraClassDeclaration = [{
     VectorType getLhsType() {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3d3e49134363f..2f48cdf2f026f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -818,13 +818,18 @@ void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                   Value lhs, Value rhs, Value acc,
                                   ArrayAttr indexingMaps,
-                                  ArrayAttr iteratorTypes, CombiningKind kind) {
+                                  ArrayAttr iteratorTypes, CombiningKind kind,
+                                  arith::FastMathFlags fastMathFlags) {
   result.addOperands({lhs, rhs, acc});
   result.addTypes(acc.getType());
   result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
   result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
   result.addAttribute(getKindAttrName(result.name),
                       CombiningKindAttr::get(builder.getContext(), kind));
+  if (fastMathFlags != arith::FastMathFlags::none)
+    result.addAttribute(
+        getFastmathAttrName(result.name),
+        arith::FastMathFlagsAttr::get(builder.getContext(), fastMathFlags));
 }
 
 ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -921,8 +926,14 @@ void ContractionOp::print(OpAsmPrinter &p) {
 
       attrs.emplace_back(getIteratorTypesAttrName(),
                          ArrayAttr::get(getContext(), iteratorTypeNames));
-    } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
+    } else if (traitAttrsSet.count(attr.getName().strref()) > 0) {
+      // Elide fastmath when it is the default (none); absent also means none.
+      if (attr.getName() == getFastmathAttrName() &&
+          llvm::cast<arith::FastMathFlagsAttr>(attr.getValue()).getValue() ==
+              arith::FastMathFlags::none)
+        continue;
       attrs.push_back(attr);
+    }
   }
 
   auto dictAttr = DictionaryAttr::get(getContext(), attrs);
@@ -1147,7 +1158,8 @@ Type ContractionOp::getExpectedMaskType() {
 
 SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
   return SmallVector<StringRef>{getIndexingMapsAttrName(),
-                                getIteratorTypesAttrName(), getKindAttrName()};
+                                getIteratorTypesAttrName(), getKindAttrName(),
+                                getFastmathAttrName()};
 }
 
 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 14fbdd2243676..eaf7bb8109514 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -123,7 +123,8 @@ static Value reshapeStore(Location loc, Value val, Value result,
 static std::optional<Value>
 createContractArithOp(Location loc, Value x, Value y, Value acc,
                       vector::CombiningKind kind, PatternRewriter &rewriter,
-                      bool isInt, Value mask = Value()) {
+                      bool isInt, Value mask = Value(),
+                      arith::FastMathFlagsAttr fmf = {}) {
   using vector::CombiningKind;
   Value mul;
 
@@ -150,14 +151,13 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
         fma = selectPassthru(rewriter, mask, fma, acc);
       return fma;
     }
-    mul = arith::MulFOp::create(rewriter, loc, x, y);
+    mul = arith::MulFOp::create(rewriter, loc, x, y, fmf);
   }
 
   if (!acc)
     return std::optional<Value>(mul);
 
-  return makeArithReduction(rewriter, loc, kind, mul, acc,
-                            /*fastmath=*/nullptr, mask);
+  return makeArithReduction(rewriter, loc, kind, mul, acc, fmf, mask);
 }
 
 /// Return the positions of the reductions in the given map.
@@ -184,19 +184,21 @@ static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
 /// operands `x` and `y`.
 static Value createAdd(Location loc, Value x, Value y, bool isInt,
-                       PatternRewriter &rewriter) {
+                       PatternRewriter &rewriter,
+                       arith::FastMathFlagsAttr fmf = {}) {
   if (isInt)
     return arith::AddIOp::create(rewriter, loc, x, y);
-  return arith::AddFOp::create(rewriter, loc, x, y);
+  return arith::AddFOp::create(rewriter, loc, x, y, fmf);
 }
 
 /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
 /// operands `x and `y`.
 static Value createMul(Location loc, Value x, Value y, bool isInt,
-                       PatternRewriter &rewriter) {
+                       PatternRewriter &rewriter,
+                       arith::FastMathFlagsAttr fmf = {}) {
   if (isInt)
     return arith::MulIOp::create(rewriter, loc, x, y);
-  return arith::MulFOp::create(rewriter, loc, x, y);
+  return arith::MulFOp::create(rewriter, loc, x, y, fmf);
 }
 
 namespace {
@@ -705,6 +707,7 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
   Value res = arith::ConstantOp::create(rewriter, loc, dstType,
                                         rewriter.getZeroAttr(dstType));
   bool isInt = isa<IntegerType>(dstType.getElementType());
+  arith::FastMathFlagsAttr fmf = op.getFastmathAttr();
   llvm::SmallVector<Value> extractedCols;
   extractedCols.reserve(dstColumns);
   for (unsigned r = 0; r < dstRows; ++r) {
@@ -721,9 +724,10 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
       }
       Value extractedColRhs = extractedCols[c];
       Value product =
-          createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter);
-      Value sum = vector::ReductionOp::create(
-          rewriter, op.getLoc(), vector::CombiningKind::ADD, product);
+          createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter, fmf);
+      Value sum = vector::ReductionOp::create(rewriter, op.getLoc(),
+                                              vector::CombiningKind::ADD,
+                                              product, op.getFastmath());
 
       SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                               : SmallVector<int64_t, 2>{r, c};
@@ -731,7 +735,7 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
     }
   }
   if (auto acc = op.getAcc())
-    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
+    res = createAdd(op.getLoc(), res, acc, isInt, rewriter, fmf);
   return res;
 }
 
@@ -845,7 +849,8 @@ struct ContractOpToElementwise
     newRhs = vector::ExtractOp::create(rewriter, loc, newRhs, rhsOffsets);
     std::optional<Value> result =
         createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
-                              contractOp.getKind(), rewriter, isInt);
+                              contractOp.getKind(), rewriter, isInt,
+                              /*mask=*/Value(), contractOp.getFastmathAttr());
     if (result)
       return *result;
 
@@ -1053,8 +1058,9 @@ FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
       lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
                             iterIndex, d, rewriter);
 
-    Operation *lowContract = vector::ContractionOp::create(
-        rewriter, loc, lhs, rhs, acc, lowAffine, lowIter);
+    Operation *lowContract =
+        vector::ContractionOp::create(rewriter, loc, lhs, rhs, acc, lowAffine,
+                                      lowIter, op.getKind(), op.getFastmath());
     lowContract = maskOperation(rewriter, lowContract, lowMask);
     result = reshapeStore(loc, lowContract->getResult(0), result, resType,
                           resIndex, d, rewriter);
@@ -1099,13 +1105,16 @@ FailureOr<Value> ContractionOpLowering::lowerReduction(
     if (rhsType.getRank() != 1)
       return rewriter.notifyMatchFailure(
           op, "When LHS has rank 1, expected also RHS to have rank 1");
-    Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
+    arith::FastMathFlagsAttr fmf = op.getFastmathAttr();
+    Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter, fmf);
     auto kind = vector::CombiningKind::ADD;
 
     Value acc = op.getAcc();
     Operation *reductionOp =
-        acc ? vector::ReductionOp::create(rewriter, loc, kind, m, acc)
-            : vector::ReductionOp::create(rewriter, loc, kind, m);
+        acc ? vector::ReductionOp::create(rewriter, loc, kind, m, acc,
+                                          op.getFastmath())
+            : vector::ReductionOp::create(rewriter, loc, kind, m,
+                                          op.getFastmath());
     return maskOperation(rewriter, reductionOp, mask)->getResult(0);
   }
   // Construct new iterator types and affine map array attribute.
@@ -1130,7 +1139,8 @@ FailureOr<Value> ContractionOpLowering::lowerReduction(
                             iterIndex, d, rewriter);
 
     Operation *newContract = vector::ContractionOp::create(
-        rewriter, loc, lhs, rhs, result, lowAffine, lowIter);
+        rewriter, loc, lhs, rhs, result, lowAffine, lowIter, op.getKind(),
+        op.getFastmath());
     result = maskOperation(rewriter, newContract, newMask)->getResult(0);
   }
   return result;
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
index 739796099f795..d00fe588f2b5b 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
@@ -308,6 +308,132 @@ func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1
   return %res : vector<2xi32>
 }
 
+// Verify that fastmath flags on vector.contract propagate to the lowered ops.
+// CHECK-LABEL: func @extract_contract2_fmf
+// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK:      %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<3xf32> from vector<2x3xf32>
+// CHECK:      %[[T2:.*]] = arith.mulf %[[T0]], %[[B]] fastmath<contract> : vector<3xf32>
+// CHECK:      %[[T3:.*]] = vector.reduction <add>, %[[T2]] fastmath<contract> : vector<3xf32> into f32
+// CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
+// CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : vector<3xf32> from vector<2x3xf32>
+// CHECK:      %[[T7:.*]] = arith.mulf %[[T5]], %[[B]] fastmath<contract> : vector<3xf32>
+// CHECK:      %[[T8:.*]] = vector.reduction <add>, %[[T7]] fastmath<contract> : vector<3xf32> into f32
+// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
+// CHECK:      %[[T10:.*]] = arith.addf %[[T9]], %[[C]] fastmath<contract> : vector<2xf32>
+// CHECK:      return %[[T10]] : vector<2xf32>
+
+func.func @extract_contract2_fmf(%arg0: vector<2x3xf32>,
+                        %arg1: vector<3xf32>,
+                        %arg2: vector<2xf32>) -> vector<2xf32> {
+  %0 = vector.contract {
+    indexing_maps = #matvec_accesses,
+    iterator_types = ["parallel", "reduction"],
+    fastmath = #arith.fastmath<contract>
+  } %arg0, %arg1, %arg2 : vector<2x3xf32>, vector<3xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// Verify that fastmath flags propagate through matmat (parallel,parallel,reduction) lowering.
+// CHECK-LABEL: func @contract_to_dot_matmat_fmf
+// CHECK-SAME: %[[LHS:.*0]]: vector<2x2xf32>,
+// CHECK-SAME: %[[RHS:.*1]]: vector<2x2xf32>,
+// CHECK-SAME: %[[OUT:.*2]]: vector<2x2xf32>
+// CHECK:    %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK:    %[[RHS_T:.*]] = vector.transpose %[[RHS]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK:    %[[LHS0:.*]] = vector.extract %[[LHS]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK:    %[[RHS_T0:.*]] = vector.extract %[[RHS_T]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK:    %[[PROD0:.*]] = arith.mulf %[[LHS0]], %[[RHS_T0]] fastmath<contract> : vector<2xf32>
+// CHECK:    %[[SUM0:.*]] = vector.reduction <add>, %[[PROD0]] fastmath<contract> : vector<2xf32> into f32
+// CHECK:    %[[RES0:.*]] = vector.insert %[[SUM0]], %[[INIT]] [0, 0] : f32 into vector<2x2xf32>
+// CHECK:    %[[RHS_T1:.*]] = vector.extract %[[RHS_T]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK:    %[[PROD1:.*]] = arith.mulf %[[LHS0]], %[[RHS_T1]] fastmath<contract> : vector<2xf32>
+// CHECK:    %[[SUM1:.*]] = vector.reduction <add>, %[[PROD1]] fastmath<contract> : vector<2xf32> into f32
+// CHECK:    %[[RES1:.*]] = vector.insert %[[SUM1]], %[[RES0]] [0, 1] : f32 into vector<2x2xf32>
+// CHECK:    %[[LHS1:.*]] = vector.extract %[[LHS]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK:    %[[PROD2:.*]] = arith.mulf %[[LHS1]], %[[RHS_T0]] fastmath<contract> : vector<2xf32>
+// CHECK:    %[[SUM2:.*]] = vector.reduction <add>, %[[PROD2]] fastmath<contract> : vector<2xf32> into f32
+// CHECK:    %[[RES2:.*]] = vector.insert %[[SUM2]], %[[RES1]] [1, 0] : f32 into vector<2x2xf32>
+// CHECK:    %[[PROD3:.*]] = arith.mulf %[[LHS1]], %[[RHS_T1]] fastmath<contract> : vector<2xf32>
+// CHECK:    %[[SUM3:.*]] = vector.reduction <add>, %[[PROD3]] fastmath<contract> : vector<2xf32> into f32
+// CHECK:    %[[RES3:.*]] = vector.insert %[[SUM3]], %[[RES2]] [1, 1] : f32 into vector<2x2xf32>
+// CHECK:    %[[RES:.*]] = arith.addf %[[RES3]], %[[OUT]] fastmath<contract> : vector<2x2xf32>
+// CHECK:    return %[[RES]] : vector<2x2xf32>
+
+func.func @contract_to_dot_matmat_fmf(%lhs: vector<2x2xf32>,
+                        %rhs: vector<2x2xf32>,
+                        %init: vector<2x2xf32>) -> vector<2x2xf32> {
+  %res = vector.contract {
+    indexing_maps = #matmat_accesses,
+    iterator_types = ["parallel", "parallel", "reduction"],
+    fastmath = #arith.fastmath<contract>
+  } %lhs, %rhs, %init : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+  return %res : vector<2x2xf32>
+}
+
+// CHECK-LABEL: func @full_contract1_fmf
+// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
+// CHECK-SAME: %[[C:.*2]]: f32
+// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<3xf32> from vector<2x3xf32>
+// CHECK:      %[[T1:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<2x3xf32>
+// CHECK:      %[[T2:.*]] = arith.mulf %[[T0]], %[[T1]] fastmath<reassoc> : vector<3xf32>
+// CHECK:      %[[T3:.*]] = vector.reduction <add>, %[[T2]], %[[C]] fastmath<reassoc> : vector<3xf32> into f32
+// CHECK:      %[[T4:.*]] = vector.extract %[[A]][1] : vector<3xf32> from vector<2x3xf32>
+// CHECK:      %[[T5:.*]] = vector.extract %[[B]][1] : vector<3xf32> from vector<2x3xf32>
+// CHECK:      %[[T6:.*]] = arith.mulf %[[T4]], %[[T5]] fastmath<reassoc> : vector<3xf32>
+// CHECK:      %[[T7:.*]] = vector.reduction <add>, %[[T6]], %[[T3]] fastmath<reassoc> : vector<3xf32> into f32
+// CHECK:      return %[[T7]] : f32
+
+func.func @full_contract1_fmf(%arg0: vector<2x3xf32>,
+                               %arg1: vector<2x3xf32>,
+                               %arg2: f32) -> f32 {
+  %0 = vector.contract {
+    indexing_maps = #contraction2d_accesses,
+    iterator_types = ["reduction", "reduction"],
+    fastmath = #arith.fastmath<reassoc>
+  } %arg0, %arg1, %arg2 : vector<2x3xf32>, vector<2x3xf32> into f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: func @batch_contract_fmf
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK:      %[[ZERO:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+// CHECK:      %[[A0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK:      %[[B0:.*]] = vector.extract %[[B]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK:      %[[C0:.*]] = vector.extract %[[C]][0] : f32 from vector<2xf32>
+// CHECK:      %[[M0:.*]] = arith.mulf %[[A0]], %[[B0]] fastmath<reassoc> : vector<2xf32>
+// CHECK:      %[[R0:.*]] = vector.reduction <add>, %[[M0]], %[[C0]] fastmath<reassoc> : vector<2xf32> into f32
+// CHECK:      %[[V0:.*]] = vector.insert %[[R0]], %[[ZERO]] [0] : f32 into vector<2xf32>
+// CHECK:      %[[A1:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK:      %[[B1:.*]] = vector.extract %[[B]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK:      %[[C1:.*]] = vector.extract %[[C]][1] : f32 from vector<2xf32>
+// CHECK:      %[[M1:.*]] = arith.mulf %[[A1]], %[[B1]] fastmath<reassoc> : vector<2xf32>
+// CHECK:      %[[R1:.*]] = vector.reduction <add>, %[[M1]], %[[C1]] fastmath<reassoc> : vector<2xf32> into f32
+// CHECK:      %[[V1:.*]] = vector.insert %[[R1]], %[[V0]] [1] : f32 into vector<2xf32>
+// CHECK:      return %[[V1]] : vector<2xf32>
+
+#batch_reduce_accesses = [
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> (i)>
+]
+
+func.func @batch_contract_fmf(%arg0: vector<2x2xf32>,
+                               %arg1: vector<2x2xf32>,
+                               %arg2: vector<2xf32>) -> vector<2xf32> {
+  %0 = vector.contract {
+    indexing_maps = #batch_reduce_accesses,
+    iterator_types = ["parallel", "reduction"],
+    fastmath = #arith.fastmath<reassoc>
+  } %arg0, %arg1, %arg2 : vector<2x2xf32>, vector<2x2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op



More information about the Mlir-commits mailing list