[Mlir-commits] [mlir] [mlir][linalg] Add mixed precision folding pattern in vectorize_children_and_apply_patterns TD Op (PR #148684)
Md Asghar Ahmad Shahid
llvmlistbot at llvm.org
Tue Aug 5 04:37:05 PDT 2025
https://github.com/shahidact updated https://github.com/llvm/llvm-project/pull/148684
>From 851ec2a388be6b822cc850040a162faf8083bbbd Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Mon, 14 Jul 2025 09:39:06 -0700
Subject: [PATCH 1/2] [mlir][linalg] Add mixed precision folding pattern in
transform op.
In case of mixed precision inputs, the inputs are generally cast to
match the output type, thereby introducing arith.extFOp/extIOp instructions.
Folding such patterns into vector.contract is desirable for HW having
mixed precision ISA support.
This patch adds optional folding of mixed precision patterns into
vector.contract, which can be enabled using the attribute 'vectorize_mixed_precision'.
---
.../Linalg/TransformOps/LinalgTransformOps.td | 5 ++
.../TransformOps/LinalgTransformOps.cpp | 13 ++-
.../Linalg/transform-op-vectorize.mlir | 89 +++++++++++++++++++
3 files changed, 106 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 61ce23f07faa8..02da2ce30f215 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2348,6 +2348,9 @@ def VectorizeChildrenAndApplyPatternsOp :
operation that is contained inside the vectorization target.
This transformation supports the following attributes:
+ - `vectorize_mixed_precision`: a `UnitAttr` to activate the vectorization
+ of ops that have mixed precision types. This enables the folding of
+ arith.extFOp/arith.extIOp into vector.contract with mixed precision.
- `vectorize_padding`: a `UnitAttr` to activate the vectorization of
`tensor.pad` ops. Different pipelines may prefer to lower such ops to
loops.
@@ -2368,6 +2371,7 @@ def VectorizeChildrenAndApplyPatternsOp :
}];
let arguments = (ins TransformHandleTypeInterface:$target,
+ UnitAttr:$vectorize_mixed_precision,
UnitAttr:$vectorize_padding,
UnitAttr:$vectorize_nd_extract,
UnitAttr:$flatten_1d_depthwise_conv,
@@ -2381,6 +2385,7 @@ def VectorizeChildrenAndApplyPatternsOp :
let builders = [
OpBuilder<(ins "Value":$target,
+ CArg<"bool", "false">:$vectorizeMixedPrecision,
CArg<"bool", "false">:$vectorizePadding,
CArg<"bool", "false">:$vectorizeNDExtract,
CArg<"bool", "false">:$flatten1DDepthwise)>
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bdfc8d020e58f..416052dd28500 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3783,8 +3783,15 @@ LogicalResult TileUsingForallOp::verify() {
void transform::VectorizeChildrenAndApplyPatternsOp::build(
OpBuilder &builder, OperationState &result, Value target,
- bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
+ bool vectorizeMixedPrecision, bool vectorizePadding, bool vectorizeExtract,
+ bool flatten1DDepthwiseConv) {
result.addOperands(target);
+ if (vectorizeMixedPrecision) {
+ result.addAttribute(
+ VectorizeChildrenAndApplyPatternsOp::getVectorizeMixedPrecisionAttrName(
+ result.name),
+ builder.getUnitAttr());
+ }
if (vectorizePadding) {
result.addAttribute(
VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
@@ -3875,6 +3882,10 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
patterns.add<CopyVectorizationPattern>(ctx);
+ if (getVectorizeMixedPrecision()) {
+ vector::populateFoldArithExtensionPatterns(patterns);
+ }
+
if (getVectorizePadding()) {
linalg::populatePadOpVectorizationPatterns(patterns);
// This creates an alternative path for lowering tensor.pad - by
diff --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
index 0d59dbba8940d..96f89653d20ca 100644
--- a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
@@ -190,3 +190,92 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// Mixed Precision vetorization tests.
+
+// CHECK-LABEL: func @mixed_precision_generic_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+// CHECK: vector.transfer_write
+func.func @mixed_precision_generic_as_contract(%A: memref<8x16xbf16>, %B: memref<16x32xbf16>,
+ %C: memref<8x32xf32>) {
+ linalg.generic {
+ indexing_maps = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"]
+ }
+ ins(%A, %B : memref<8x16xbf16>, memref<16x32xbf16>)
+ outs(%C : memref<8x32xf32>) {
+ ^bb(%in: bf16, %in_0: bf16, %c: f32) :
+ %a = arith.extf %in : bf16 to f32
+ %b = arith.extf %in_0 : bf16 to f32
+ %d = arith.mulf %a, %b: f32
+ %e = arith.addf %c, %d: f32
+ linalg.yield %e : f32
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @mixed_precision_matmul_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+// CHECK: vector.transfer_write
+func.func @mixed_precision_matmul_as_contract(%A: tensor<24x12xbf16>,
+ %B: tensor<12x25xbf16>,
+ %C: tensor<24x25xf32>) -> tensor<24x25xf32> {
+ %0 = linalg.contract
+ indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>]
+ ins(%A, %B : tensor<24x12xbf16>, tensor<12x25xbf16>)
+ outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
+ func.return %0 : tensor<24x25xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @contraction_matmul
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+func.func @contraction_matmul(%A: memref<1584x1584xbf16>, %B: memref<1584x1584xbf16>, %C: memref<1584x1584xf32>) {
+ linalg.matmul ins(%A, %B: memref<1584x1584xbf16>, memref<1584x1584xbf16>)
+ outs(%C: memref<1584x1584xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
>From 0c11c3990a7310b21cf77d3a8676ec0e3a35f1d3 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Tue, 5 Aug 2025 04:33:31 -0700
Subject: [PATCH 2/2] -Moved the tests to the appropriate place and added a few more
tests. -Refactored some code and comments.
---
.../Linalg/TransformOps/LinalgTransformOps.td | 7 +-
.../TransformOps/LinalgTransformOps.cpp | 13 +-
.../Linalg/transform-op-vectorize.mlir | 91 +---------
.../linalg-ops-with-patterns.mlir | 155 ++++++++++++++++++
4 files changed, 165 insertions(+), 101 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 02da2ce30f215..8e9ba50a61416 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2348,8 +2348,7 @@ def VectorizeChildrenAndApplyPatternsOp :
operation that is contained inside the vectorization target.
This transformation supports the following attributes:
- - `vectorize_mixed_precision`: a `UnitAttr` to activate the vectorization
- of ops that have mixed precision types. This enables the folding of
+ - `fold_mixed_precision_into_contract`: a `UnitAttr` to enable the folding of
arith.extFOp/arith.extIOp into vector.contract with mixed precision.
- `vectorize_padding`: a `UnitAttr` to activate the vectorization of
`tensor.pad` ops. Different pipelines may prefer to lower such ops to
@@ -2371,7 +2370,7 @@ def VectorizeChildrenAndApplyPatternsOp :
}];
let arguments = (ins TransformHandleTypeInterface:$target,
- UnitAttr:$vectorize_mixed_precision,
+ UnitAttr:$fold_mixed_precision_into_contract,
UnitAttr:$vectorize_padding,
UnitAttr:$vectorize_nd_extract,
UnitAttr:$flatten_1d_depthwise_conv,
@@ -2385,7 +2384,7 @@ def VectorizeChildrenAndApplyPatternsOp :
let builders = [
OpBuilder<(ins "Value":$target,
- CArg<"bool", "false">:$vectorizeMixedPrecision,
+ CArg<"bool", "false">:$foldMixedPrecisionIntoContract,
CArg<"bool", "false">:$vectorizePadding,
CArg<"bool", "false">:$vectorizeNDExtract,
CArg<"bool", "false">:$flatten1DDepthwise)>
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 416052dd28500..13c84a79c227c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3783,13 +3783,13 @@ LogicalResult TileUsingForallOp::verify() {
void transform::VectorizeChildrenAndApplyPatternsOp::build(
OpBuilder &builder, OperationState &result, Value target,
- bool vectorizeMixedPrecision, bool vectorizePadding, bool vectorizeExtract,
- bool flatten1DDepthwiseConv) {
+ bool foldMixedPrecisionIntoContract, bool vectorizePadding,
+ bool vectorizeExtract, bool flatten1DDepthwiseConv) {
result.addOperands(target);
- if (vectorizeMixedPrecision) {
+ if (foldMixedPrecisionIntoContract) {
result.addAttribute(
- VectorizeChildrenAndApplyPatternsOp::getVectorizeMixedPrecisionAttrName(
- result.name),
+ VectorizeChildrenAndApplyPatternsOp::
+ getFoldMixedPrecisionIntoContractAttrName(result.name),
builder.getUnitAttr());
}
if (vectorizePadding) {
@@ -3882,9 +3882,8 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
patterns.add<CopyVectorizationPattern>(ctx);
- if (getVectorizeMixedPrecision()) {
+ if (getFoldMixedPrecisionIntoContract())
vector::populateFoldArithExtensionPatterns(patterns);
- }
if (getVectorizePadding()) {
linalg::populatePadOpVectorizationPatterns(patterns);
diff --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
index 96f89653d20ca..e0c5cddfaf30b 100644
--- a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
@@ -189,93 +189,4 @@ module attributes {transform.with_named_sequence} {
%2 = transform.structured.vectorize_children_and_apply_patterns %0 : (!transform.any_op) -> !transform.any_op
transform.yield
}
-}
-
-// -----
-
-// Mixed Precision vetorization tests.
-
-// CHECK-LABEL: func @mixed_precision_generic_as_contract
-// CHECK-COUNT-3: vector.transfer_read
-// CHECK-NOT: arith.extf
-// CHECK: vector.contract
-// CHECK: vector.transfer_write
-func.func @mixed_precision_generic_as_contract(%A: memref<8x16xbf16>, %B: memref<16x32xbf16>,
- %C: memref<8x32xf32>) {
- linalg.generic {
- indexing_maps = [
- affine_map<(m, n, k) -> (m, k)>,
- affine_map<(m, n, k) -> (k, n)>,
- affine_map<(m, n, k) -> (m, n)>
- ],
- iterator_types = ["parallel", "parallel", "reduction"]
- }
- ins(%A, %B : memref<8x16xbf16>, memref<16x32xbf16>)
- outs(%C : memref<8x32xf32>) {
- ^bb(%in: bf16, %in_0: bf16, %c: f32) :
- %a = arith.extf %in : bf16 to f32
- %b = arith.extf %in_0 : bf16 to f32
- %d = arith.mulf %a, %b: f32
- %e = arith.addf %c, %d: f32
- linalg.yield %e : f32
- }
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: @mixed_precision_matmul_as_contract
-// CHECK-COUNT-3: vector.transfer_read
-// CHECK-NOT: arith.extf
-// CHECK: vector.contract
-// CHECK: vector.transfer_write
-func.func @mixed_precision_matmul_as_contract(%A: tensor<24x12xbf16>,
- %B: tensor<12x25xbf16>,
- %C: tensor<24x25xf32>) -> tensor<24x25xf32> {
- %0 = linalg.contract
- indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
- affine_map<(m, n, k) -> (k, n)>,
- affine_map<(m, n, k) -> (m, n)>]
- ins(%A, %B : tensor<24x12xbf16>, tensor<12x25xbf16>)
- outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
- func.return %0 : tensor<24x25xf32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: @contraction_matmul
-// CHECK-COUNT-3: vector.transfer_read
-// CHECK-NOT: arith.extf
-// CHECK: vector.contract
-func.func @contraction_matmul(%A: memref<1584x1584xbf16>, %B: memref<1584x1584xbf16>, %C: memref<1584x1584xf32>) {
- linalg.matmul ins(%A, %B: memref<1584x1584xbf16>, memref<1584x1584xbf16>)
- outs(%C: memref<1584x1584xf32>)
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
index 4eeae4c064519..86a09682ec159 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
@@ -1777,3 +1777,158 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// Mixed precision vectorization tests.
+
+// CHECK-LABEL: func @float_mixed_precision_generic_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+// CHECK: vector.transfer_write
+func.func @float_mixed_precision_generic_as_contract(%A: memref<8x16xbf16>, %B: memref<16x32xbf16>,
+ %C: memref<8x32xf32>) {
+ linalg.generic {
+ indexing_maps = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"]
+ }
+ ins(%A, %B : memref<8x16xbf16>, memref<16x32xbf16>)
+ outs(%C : memref<8x32xf32>) {
+ ^bb(%in: bf16, %in_0: bf16, %c: f32) :
+ %a = arith.extf %in : bf16 to f32
+ %b = arith.extf %in_0 : bf16 to f32
+ %d = arith.mulf %a, %b: f32
+ %e = arith.addf %c, %d: f32
+ linalg.yield %e : f32
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_mixed_precision_into_contract, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @integer_mixed_precision_generic_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extsi
+// CHECK: vector.contract
+// CHECK: vector.transfer_write
+func.func @integer_mixed_precision_generic_as_contract(%A: memref<8x16xi8>, %B: memref<16x32xi8>,
+ %C: memref<8x32xi32>) {
+ linalg.generic {
+ indexing_maps = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"]
+ }
+ ins(%A, %B : memref<8x16xi8>, memref<16x32xi8>)
+ outs(%C : memref<8x32xi32>) {
+ ^bb(%in: i8, %in_0: i8, %c: i32) :
+ %a = arith.extsi %in : i8 to i32
+ %b = arith.extsi %in_0 : i8 to i32
+ %d = arith.muli %a, %b: i32
+ %e = arith.addi %c, %d: i32
+ linalg.yield %e : i32
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_mixed_precision_into_contract, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @float_mixed_precision_matmul_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+// CHECK: vector.transfer_write
+func.func @float_mixed_precision_matmul_as_contract(%A: tensor<24x12xbf16>,
+ %B: tensor<12x25xbf16>,
+ %C: tensor<24x25xf32>) -> tensor<24x25xf32> {
+ %0 = linalg.contract
+ indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>]
+ ins(%A, %B : tensor<24x12xbf16>, tensor<12x25xbf16>)
+ outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
+ func.return %0 : tensor<24x25xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_mixed_precision_into_contract } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @integer_mixed_precision_matmul_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+// CHECK: vector.transfer_write
+func.func @integer_mixed_precision_matmul_as_contract(%A: tensor<24x12xi8>,
+ %B: tensor<12x25xi8>,
+ %C: tensor<24x25xi32>) -> tensor<24x25xi32> {
+ %0 = linalg.contract
+ indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>]
+ ins(%A, %B : tensor<24x12xi8>, tensor<12x25xi8>)
+ outs(%C : tensor<24x25xi32>) -> tensor<24x25xi32>
+ func.return %0 : tensor<24x25xi32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_mixed_precision_into_contract } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @contraction_matmul
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+func.func @contraction_matmul(%A: memref<1584x1584xbf16>, %B: memref<1584x1584xbf16>, %C: memref<1584x1584xf32>) {
+ linalg.matmul ins(%A, %B: memref<1584x1584xbf16>, memref<1584x1584xbf16>)
+ outs(%C: memref<1584x1584xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_mixed_precision_into_contract } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
More information about the Mlir-commits
mailing list