[Mlir-commits] [mlir] 0ff1048 - [mlir][vector] Add transform.apply_patterns.vector.fold_arith_extension

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 28 06:10:45 PDT 2023


Author: Groverkss
Date: 2023-07-28T18:40:31+05:30
New Revision: 0ff1048409da484de1fc912ea1c4552256cb6e6c

URL: https://github.com/llvm/llvm-project/commit/0ff1048409da484de1fc912ea1c4552256cb6e6c
DIFF: https://github.com/llvm/llvm-project/commit/0ff1048409da484de1fc912ea1c4552256cb6e6c.diff

LOG: [mlir][vector] Add transform.apply_patterns.vector.fold_arith_extension

This patch adds a transform op that exposes the FoldArithExtIntoContractionOp
pattern. The pattern folds arith.extf into vector.contract for backends
with native support for mixed-mode contractions.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D156484

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
    mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
    mlir/test/Dialect/Vector/transform-vector.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 253aeedf15aba5..d0951823f3068d 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -306,4 +306,15 @@ def ApplyTransferToScfPatternsOp : Op<Transform_Dialect,
   }];
 }
 
+def ApplyFoldArithExtensionPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.fold_arith_extension",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collect a set of patterns that fold floating-point `arith.extf` ops into
+    `vector.contract` for backends with native mixed-mode contraction support.
+  }];
+
+  let assemblyFormat = "attr-dict"; // No operands/results: prints bare name.
+}
+
 #endif // VECTOR_TRANSFORM_OPS

diff  --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index da99232ed6ab8f..1866de6cb14703 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -32,6 +32,11 @@ void transform::ApplyCastAwayVectorLeadingOneDimPatternsOp::populatePatterns(
   vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
 }
 
+void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) { // Folds arith.extf into vector.contract.
+  vector::populateFoldArithExtensionPatterns(patterns);
+}
+
 void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::populateVectorTransferDropUnitDimsPatterns(patterns);

diff  --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index d050885ed30fc1..3e62a8fbf718f9 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -62,3 +62,29 @@ transform.sequence failures(propagate) {
     transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d"
   } : !transform.any_op
 }
+
+// -----
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @fold_arith_extf_into_contract
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<64x64xf16>, %[[ARG1:.*]]: vector<64x64xf16>, %[[ARG2:.*]]: vector<64x64xf32>)
+//  CHECK-NEXT:   %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
+//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+//  CHECK-SAME:   %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<64x64xf16>, vector<64x64xf16> into vector<64x64xf32>
+//  CHECK-NEXT:   return %[[R]] : vector<64x64xf32>
+func.func @fold_arith_extf_into_contract(%arg0: vector<64x64xf16>, %arg1: vector<64x64xf16>, %arg2: vector<64x64xf32>) -> vector<64x64xf32> {
+    %lhs_f32 = arith.extf %arg0 : vector<64x64xf16> to vector<64x64xf32> // Should fold into the contract below.
+    %rhs_f32 = arith.extf %arg1 : vector<64x64xf16> to vector<64x64xf32> // Likewise folded; acc stays f32.
+    %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_f32, %rhs_f32, %arg2 : vector<64x64xf32>, vector<64x64xf32> into vector<64x64xf32>
+    return %result : vector<64x64xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func { // Apply the fold patterns to the matched func.
+    transform.apply_patterns.vector.fold_arith_extension
+  } : !transform.any_op
+}


        


More information about the Mlir-commits mailing list