[Mlir-commits] [mlir] [mlir][linalg] Add mixed precision folding pattern in transform op. (PR #148684)

Md Asghar Ahmad Shahid llvmlistbot at llvm.org
Mon Jul 14 10:34:42 PDT 2025


https://github.com/shahidact created https://github.com/llvm/llvm-project/pull/148684

In the case of mixed precision inputs, the inputs are generally cast to match the output type, which introduces arith.extFOp/extIOp instructions.

Folding such a pattern into vector.contract is desirable for HW with mixed precision ISA support.

This patch adds an optional folding of the mixed precision pattern into vector.contract, enabled via the 'vectorize_mixed_precision' attribute.
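For reference, a minimal before/after sketch of the fold performed by vector::populateFoldArithExtensionPatterns on the bf16 -> f32 case exercised in the tests below (illustrative only, not output of this patch; %lhs, %rhs and %acc are made-up value names):

Before folding, vectorization extends both operands to the accumulator type:

  %a = arith.extf %lhs : vector<8x16xbf16> to vector<8x16xf32>
  %b = arith.extf %rhs : vector<16x32xbf16> to vector<16x32xf32>
  %d = vector.contract {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                                         affine_map<(m, n, k) -> (k, n)>,
                                         affine_map<(m, n, k) -> (m, n)>],
                        iterator_types = ["parallel", "parallel", "reduction"],
                        kind = #vector.kind<add>}
        %a, %b, %acc : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>

After folding, the extensions are absorbed and the contract consumes the narrow operands directly while still accumulating in f32:

  %d = vector.contract {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                                         affine_map<(m, n, k) -> (k, n)>,
                                         affine_map<(m, n, k) -> (m, n)>],
                        iterator_types = ["parallel", "parallel", "reduction"],
                        kind = #vector.kind<add>}
        %lhs, %rhs, %acc : vector<8x16xbf16>, vector<16x32xbf16> into vector<8x32xf32>

The integer case mentioned above is analogous, e.g. an extension of i8 operands folding into a contract that accumulates in i32.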

From 1034cb591cabcb3b2683b8d05db453c90a17190b Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Mon, 14 Jul 2025 09:39:06 -0700
Subject: [PATCH] [mlir][linalg] Add mixed precision folding pattern in
 transform op.

In the case of mixed precision inputs, the inputs are generally cast
to match the output type, which introduces arith.extFOp/extIOp
instructions.

Folding such a pattern into vector.contract is desirable for HW with
mixed precision ISA support.

This patch adds an optional folding of the mixed precision pattern
into vector.contract, enabled via the 'vectorize_mixed_precision'
attribute.
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  5 ++
 .../TransformOps/LinalgTransformOps.cpp       | 13 ++-
 .../Linalg/transform-op-vectorize.mlir        | 89 +++++++++++++++++++
 3 files changed, 106 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index b4dde776822a1..dc4e6718907f2 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2347,6 +2347,9 @@ def VectorizeChildrenAndApplyPatternsOp :
     operation that is contained inside the vectorization target.
 
     This transformation supports the following attributes:
+    - `vectorize_mixed_precision`: a `UnitAttr` to activate the vectorization
+      of ops that have mixed precision types. This enables the folding of
+      arith.extFOp/arith.extIOp into vector.contract with mixed precision.
     - `vectorize_padding`: a `UnitAttr` to activate the vectorization of
       `tensor.pad` ops. Different pipelines may prefer to lower such ops to
       loops.
@@ -2367,6 +2370,7 @@ def VectorizeChildrenAndApplyPatternsOp :
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
+                   UnitAttr:$vectorize_mixed_precision,
                    UnitAttr:$vectorize_padding,
                    UnitAttr:$vectorize_nd_extract,
                    UnitAttr:$flatten_1d_depthwise_conv,
@@ -2380,6 +2384,7 @@ def VectorizeChildrenAndApplyPatternsOp :
 
   let builders = [
     OpBuilder<(ins "Value":$target,
+               CArg<"bool", "false">:$vectorizeMixedPrecision,
                CArg<"bool", "false">:$vectorizePadding,
                CArg<"bool", "false">:$vectorizeNDExtract,
                CArg<"bool", "false">:$flatten1DDepthwise)>
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5d5f9de465561..c8f256cf38c9d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3784,8 +3784,15 @@ LogicalResult TileUsingForallOp::verify() {
 
 void transform::VectorizeChildrenAndApplyPatternsOp::build(
     OpBuilder &builder, OperationState &result, Value target,
-    bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
+    bool vectorizeMixedPrecision, bool vectorizePadding, bool vectorizeExtract,
+    bool flatten1DDepthwiseConv) {
   result.addOperands(target);
+  if (vectorizeMixedPrecision) {
+    result.addAttribute(
+        VectorizeChildrenAndApplyPatternsOp::getVectorizeMixedPrecisionAttrName(
+            result.name),
+        builder.getUnitAttr());
+  }
   if (vectorizePadding) {
     result.addAttribute(
         VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
@@ -3876,6 +3883,10 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
 
   patterns.add<CopyVectorizationPattern>(ctx);
 
+  if (getVectorizeMixedPrecision()) {
+    vector::populateFoldArithExtensionPatterns(patterns);
+  }
+
   if (getVectorizePadding()) {
     linalg::populatePadOpVectorizationPatterns(patterns);
     // This creates an alternative path for lowering tensor.pad - by
diff --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
index 0d59dbba8940d..96f89653d20ca 100644
--- a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
@@ -190,3 +190,92 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// Mixed precision vectorization tests.
+
+// CHECK-LABEL: func @mixed_precision_generic_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT:     arith.extf
+// CHECK:         vector.contract
+// CHECK:         vector.transfer_write
+func.func @mixed_precision_generic_as_contract(%A: memref<8x16xbf16>, %B: memref<16x32xbf16>,
+                         %C: memref<8x32xf32>) {
+  linalg.generic {
+    indexing_maps = [
+    affine_map<(m, n, k) -> (m, k)>,
+    affine_map<(m, n, k) -> (k, n)>,
+    affine_map<(m, n, k) -> (m, n)>
+    ],
+    iterator_types = ["parallel", "parallel", "reduction"]
+  }
+   ins(%A, %B : memref<8x16xbf16>, memref<16x32xbf16>)
+   outs(%C : memref<8x32xf32>) {
+    ^bb(%in: bf16, %in_0: bf16, %c: f32) :
+      %a = arith.extf %in : bf16 to f32
+      %b = arith.extf %in_0 : bf16 to f32
+      %d = arith.mulf %a, %b: f32
+      %e = arith.addf %c, %d: f32
+      linalg.yield %e : f32
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @mixed_precision_matmul_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT:     arith.extf
+// CHECK:         vector.contract
+// CHECK:         vector.transfer_write
+func.func @mixed_precision_matmul_as_contract(%A: tensor<24x12xbf16>,
+                              %B: tensor<12x25xbf16>,
+                              %C: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  %0 = linalg.contract
+      indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+                       affine_map<(m, n, k) -> (k, n)>,
+                       affine_map<(m, n, k) -> (m, n)>]
+      ins(%A, %B : tensor<24x12xbf16>, tensor<12x25xbf16>)
+      outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @contraction_matmul
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT:     arith.extf
+// CHECK:         vector.contract
+func.func @contraction_matmul(%A: memref<1584x1584xbf16>, %B: memref<1584x1584xbf16>, %C: memref<1584x1584xf32>) {
+  linalg.matmul ins(%A, %B: memref<1584x1584xbf16>, memref<1584x1584xbf16>)
+            outs(%C: memref<1584x1584xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list