[Mlir-commits] [mlir] [MLIR][Vector] Add fastmath attribute to vector.contract (PR #192788)
Durgadoss R
llvmlistbot at llvm.org
Wed Apr 22 06:14:41 PDT 2026
================
@@ -308,6 +308,132 @@ func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1
return %res : vector<2xi32>
}
+// Verify that fastmath flags on vector.contract propagate to the lowered ops.
+// CHECK-LABEL: func @extract_contract2_fmf
+// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<3xf32> from vector<2x3xf32>
+// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[B]] fastmath<contract> : vector<3xf32>
+// CHECK: %[[T3:.*]] = vector.reduction <add>, %[[T2]] fastmath<contract> : vector<3xf32> into f32
+// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<3xf32> from vector<2x3xf32>
+// CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[B]] fastmath<contract> : vector<3xf32>
+// CHECK: %[[T8:.*]] = vector.reduction <add>, %[[T7]] fastmath<contract> : vector<3xf32> into f32
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
+// CHECK: %[[T10:.*]] = arith.addf %[[T9]], %[[C]] fastmath<contract> : vector<2xf32>
+// CHECK: return %[[T10]] : vector<2xf32>
+
+// Matrix-vector contraction carrying fastmath<contract>. After lowering, every
+// arith.mulf, every vector.reduction and the final accumulator arith.addf are
+// expected to carry the same flag. The vector.extract/insert ops carry none.
+func.func @extract_contract2_fmf(%arg0: vector<2x3xf32>,
+ %arg1: vector<3xf32>,
+ %arg2: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.contract {
+ indexing_maps = #matvec_accesses,
+ iterator_types = ["parallel", "reduction"],
+ fastmath = #arith.fastmath<contract>
+ } %arg0, %arg1, %arg2 : vector<2x3xf32>, vector<3xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// Verify that fastmath flags propagate through matmat (parallel,parallel,reduction) lowering.
+// CHECK-LABEL: func @contract_to_dot_matmat_fmf
+// CHECK-SAME: %[[LHS:.*0]]: vector<2x2xf32>,
+// CHECK-SAME: %[[RHS:.*1]]: vector<2x2xf32>,
+// CHECK-SAME: %[[OUT:.*2]]: vector<2x2xf32>
+// CHECK: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK: %[[RHS_T:.*]] = vector.transpose %[[RHS]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK: %[[LHS0:.*]] = vector.extract %[[LHS]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[RHS_T0:.*]] = vector.extract %[[RHS_T]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[PROD0:.*]] = arith.mulf %[[LHS0]], %[[RHS_T0]] fastmath<contract> : vector<2xf32>
+// CHECK: %[[SUM0:.*]] = vector.reduction <add>, %[[PROD0]] fastmath<contract> : vector<2xf32> into f32
+// CHECK: %[[RES0:.*]] = vector.insert %[[SUM0]], %[[INIT]] [0, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[RHS_T1:.*]] = vector.extract %[[RHS_T]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[PROD1:.*]] = arith.mulf %[[LHS0]], %[[RHS_T1]] fastmath<contract> : vector<2xf32>
+// CHECK: %[[SUM1:.*]] = vector.reduction <add>, %[[PROD1]] fastmath<contract> : vector<2xf32> into f32
+// CHECK: %[[RES1:.*]] = vector.insert %[[SUM1]], %[[RES0]] [0, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[LHS1:.*]] = vector.extract %[[LHS]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[PROD2:.*]] = arith.mulf %[[LHS1]], %[[RHS_T0]] fastmath<contract> : vector<2xf32>
+// CHECK: %[[SUM2:.*]] = vector.reduction <add>, %[[PROD2]] fastmath<contract> : vector<2xf32> into f32
+// CHECK: %[[RES2:.*]] = vector.insert %[[SUM2]], %[[RES1]] [1, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[PROD3:.*]] = arith.mulf %[[LHS1]], %[[RHS_T1]] fastmath<contract> : vector<2xf32>
+// CHECK: %[[SUM3:.*]] = vector.reduction <add>, %[[PROD3]] fastmath<contract> : vector<2xf32> into f32
+// CHECK: %[[RES3:.*]] = vector.insert %[[SUM3]], %[[RES2]] [1, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[RES:.*]] = arith.addf %[[RES3]], %[[OUT]] fastmath<contract> : vector<2x2xf32>
+// CHECK: return %[[RES]] : vector<2x2xf32>
+
+// 2x2 matmat contraction with fastmath<contract>. The lowering transposes the
+// RHS and emits one mulf + reduction pair per output element. Each pair, and
+// the final accumulator add, should keep the flag.
+// NOTE(review): the argument captures above use `.*0`/`.*1`/`.*2`, so they
+// presumably match the printer-assigned %argN names rather than %lhs/%rhs/%init.
+func.func @contract_to_dot_matmat_fmf(%lhs: vector<2x2xf32>,
+ %rhs: vector<2x2xf32>,
+ %init: vector<2x2xf32>) -> vector<2x2xf32> {
+ %res = vector.contract {
+ indexing_maps = #matmat_accesses,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ fastmath = #arith.fastmath<contract>
+ } %lhs, %rhs, %init : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+ return %res : vector<2x2xf32>
+}
+
+// CHECK-LABEL: func @full_contract1_fmf
+// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
+// CHECK-SAME: %[[C:.*2]]: f32
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<3xf32> from vector<2x3xf32>
+// CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<2x3xf32>
+// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[T1]] fastmath<reassoc> : vector<3xf32>
+// CHECK: %[[T3:.*]] = vector.reduction <add>, %[[T2]], %[[C]] fastmath<reassoc> : vector<3xf32> into f32
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<3xf32> from vector<2x3xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<3xf32> from vector<2x3xf32>
+// CHECK: %[[T6:.*]] = arith.mulf %[[T4]], %[[T5]] fastmath<reassoc> : vector<3xf32>
+// CHECK: %[[T7:.*]] = vector.reduction <add>, %[[T6]], %[[T3]] fastmath<reassoc> : vector<3xf32> into f32
+// CHECK: return %[[T7]] : f32
+
+// Full reduction to a scalar (both iterators are "reduction") with
+// fastmath<reassoc>. Each row's mulf and the chained vector.reduction, which
+// threads the running accumulator through, are expected to carry the flag.
+func.func @full_contract1_fmf(%arg0: vector<2x3xf32>,
+ %arg1: vector<2x3xf32>,
+ %arg2: f32) -> f32 {
+ %0 = vector.contract {
+ indexing_maps = #contraction2d_accesses,
+ iterator_types = ["reduction", "reduction"],
+ fastmath = #arith.fastmath<reassoc>
+ } %arg0, %arg1, %arg2 : vector<2x3xf32>, vector<2x3xf32> into f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func @batch_contract_fmf
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[ZERO:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+// CHECK: %[[A0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[B0:.*]] = vector.extract %[[B]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[C0:.*]] = vector.extract %[[C]][0] : f32 from vector<2xf32>
+// CHECK: %[[M0:.*]] = arith.mulf %[[A0]], %[[B0]] fastmath<reassoc> : vector<2xf32>
+// CHECK: %[[R0:.*]] = vector.reduction <add>, %[[M0]], %[[C0]] fastmath<reassoc> : vector<2xf32> into f32
+// CHECK: %[[V0:.*]] = vector.insert %[[R0]], %[[ZERO]] [0] : f32 into vector<2xf32>
+// CHECK: %[[A1:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[B1:.*]] = vector.extract %[[B]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[C1:.*]] = vector.extract %[[C]][1] : f32 from vector<2xf32>
+// CHECK: %[[M1:.*]] = arith.mulf %[[A1]], %[[B1]] fastmath<reassoc> : vector<2xf32>
+// CHECK: %[[R1:.*]] = vector.reduction <add>, %[[M1]], %[[C1]] fastmath<reassoc> : vector<2xf32> into f32
+// CHECK: %[[V1:.*]] = vector.insert %[[R1]], %[[V0]] [1] : f32 into vector<2xf32>
+// CHECK: return %[[V1]] : vector<2xf32>
+
+// Both operands are indexed by (i, j) and the result only by i, so j is the
+// reduced dimension. It is used with iterator_types ["parallel", "reduction"].
+#batch_reduce_accesses = [
+ affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> (i)>
+]
+
+func.func @batch_contract_fmf(%arg0: vector<2x2xf32>,
+ %arg1: vector<2x2xf32>,
+ %arg2: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.contract {
+ indexing_maps = #batch_reduce_accesses,
+ iterator_types = ["parallel", "reduction"],
+ fastmath = #arith.fastmath<reassoc>
----------------
durga4github wrote:
I was about to ask for one more test feeding a `none` here to verify the `skip` that we added in the printer method of the contract op. But I guess the fact that the FileChecks for the existing tests pass without any change is a sign that the skip works cleanly.
LGTM.
https://github.com/llvm/llvm-project/pull/192788
More information about the Mlir-commits
mailing list