[Mlir-commits] [mlir] [mlir][linalg] Add mixed precision folding pattern in transform op. (PR #148684)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 14 10:35:14 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Md Asghar Ahmad Shahid (shahidact)
<details>
<summary>Changes</summary>
In the case of mixed precision inputs, the inputs are generally cast to match the output type, thereby introducing arith.extFOp/extIOp instructions.
Folding such patterns into vector.contract is desirable for hardware with mixed precision ISA support.
This patch adds optional folding of the mixed precision pattern into vector.contract, which can be enabled using the attribute 'vectorize_mixed_precision'.
---
Full diff: https://github.com/llvm/llvm-project/pull/148684.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+5)
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+12-1)
- (modified) mlir/test/Dialect/Linalg/transform-op-vectorize.mlir (+89)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index b4dde776822a1..dc4e6718907f2 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2347,6 +2347,9 @@ def VectorizeChildrenAndApplyPatternsOp :
operation that is contained inside the vectorization target.
This transformation supports the following attributes:
+ - `vectorize_mixed_precision`: a `UnitAttr` to activate the vectorization
+ of ops that have mixed precision types. This enables the folding of
+ arith.extFOp/arith.extIOp into vector.contract with mixed precision.
- `vectorize_padding`: a `UnitAttr` to activate the vectorization of
`tensor.pad` ops. Different pipelines may prefer to lower such ops to
loops.
@@ -2367,6 +2370,7 @@ def VectorizeChildrenAndApplyPatternsOp :
}];
let arguments = (ins TransformHandleTypeInterface:$target,
+ UnitAttr:$vectorize_mixed_precision,
UnitAttr:$vectorize_padding,
UnitAttr:$vectorize_nd_extract,
UnitAttr:$flatten_1d_depthwise_conv,
@@ -2380,6 +2384,7 @@ def VectorizeChildrenAndApplyPatternsOp :
let builders = [
OpBuilder<(ins "Value":$target,
+ CArg<"bool", "false">:$vectorizeMixedPrecision,
CArg<"bool", "false">:$vectorizePadding,
CArg<"bool", "false">:$vectorizeNDExtract,
CArg<"bool", "false">:$flatten1DDepthwise)>
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5d5f9de465561..c8f256cf38c9d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3784,8 +3784,15 @@ LogicalResult TileUsingForallOp::verify() {
void transform::VectorizeChildrenAndApplyPatternsOp::build(
OpBuilder &builder, OperationState &result, Value target,
- bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
+ bool vectorizeMixedPrecision, bool vectorizePadding, bool vectorizeExtract,
+ bool flatten1DDepthwiseConv) {
result.addOperands(target);
+ if (vectorizeMixedPrecision) {
+ result.addAttribute(
+ VectorizeChildrenAndApplyPatternsOp::getVectorizeMixedPrecisionAttrName(
+ result.name),
+ builder.getUnitAttr());
+ }
if (vectorizePadding) {
result.addAttribute(
VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
@@ -3876,6 +3883,10 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
patterns.add<CopyVectorizationPattern>(ctx);
+ if (getVectorizeMixedPrecision()) {
+ vector::populateFoldArithExtensionPatterns(patterns);
+ }
+
if (getVectorizePadding()) {
linalg::populatePadOpVectorizationPatterns(patterns);
// This creates an alternative path for lowering tensor.pad - by
diff --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
index 0d59dbba8940d..96f89653d20ca 100644
--- a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
@@ -190,3 +190,92 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// Mixed Precision vectorization tests.
+
+// CHECK-LABEL: func @mixed_precision_generic_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+// CHECK: vector.transfer_write
+func.func @mixed_precision_generic_as_contract(%A: memref<8x16xbf16>, %B: memref<16x32xbf16>,
+ %C: memref<8x32xf32>) {
+ linalg.generic {
+ indexing_maps = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"]
+ }
+ ins(%A, %B : memref<8x16xbf16>, memref<16x32xbf16>)
+ outs(%C : memref<8x32xf32>) {
+ ^bb(%in: bf16, %in_0: bf16, %c: f32) :
+ %a = arith.extf %in : bf16 to f32
+ %b = arith.extf %in_0 : bf16 to f32
+ %d = arith.mulf %a, %b: f32
+ %e = arith.addf %c, %d: f32
+ linalg.yield %e : f32
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @mixed_precision_matmul_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+// CHECK: vector.transfer_write
+func.func @mixed_precision_matmul_as_contract(%A: tensor<24x12xbf16>,
+ %B: tensor<12x25xbf16>,
+ %C: tensor<24x25xf32>) -> tensor<24x25xf32> {
+ %0 = linalg.contract
+ indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>]
+ ins(%A, %B : tensor<24x12xbf16>, tensor<12x25xbf16>)
+ outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
+ func.return %0 : tensor<24x25xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @contraction_matmul
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+func.func @contraction_matmul(%A: memref<1584x1584xbf16>, %B: memref<1584x1584xbf16>, %C: memref<1584x1584xf32>) {
+ linalg.matmul ins(%A, %B: memref<1584x1584xbf16>, memref<1584x1584xbf16>)
+ outs(%C: memref<1584x1584xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/148684
More information about the Mlir-commits
mailing list