[Mlir-commits] [mlir] 0ff1048 - [mlir][vector] Add transform.apply_patterns.vector.fold_arith_extension
llvmlistbot at llvm.org
Fri Jul 28 06:10:45 PDT 2023
Author: Groverkss
Date: 2023-07-28T18:40:31+05:30
New Revision: 0ff1048409da484de1fc912ea1c4552256cb6e6c
URL: https://github.com/llvm/llvm-project/commit/0ff1048409da484de1fc912ea1c4552256cb6e6c
DIFF: https://github.com/llvm/llvm-project/commit/0ff1048409da484de1fc912ea1c4552256cb6e6c.diff
LOG: [mlir][vector] Add transform.apply_patterns.vector.fold_arith_extension
This patch adds a transform op that applies the FoldArithExtIntoContractionOp
pattern. The pattern folds arith.extf into vector.contract, for backends with
native support for mixed-mode contractions.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D156484
Added:
Modified:
mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
mlir/test/Dialect/Vector/transform-vector.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 253aeedf15aba5..d0951823f3068d 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -306,4 +306,15 @@ def ApplyTransferToScfPatternsOp : Op<Transform_Dialect,
}];
}
+def ApplyFoldArithExtensionPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.fold_arith_extension",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Collect patterns that fold floating-point arithmetic extensions
+ (`arith.extf`) feeding the operands of `vector.contract` into the
+ contraction itself. Only use this on backends with native support for
+ mixed-mode contractions (e.g. f16 operands accumulated into f32).
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
#endif // VECTOR_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index da99232ed6ab8f..1866de6cb14703 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -32,6 +32,11 @@ void transform::ApplyCastAwayVectorLeadingOneDimPatternsOp::populatePatterns(
vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
}
+// Adds the patterns from vector::populateFoldArithExtensionPatterns, which
+// fold arith.extf producers into vector.contract, to `patterns`.
+void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateFoldArithExtensionPatterns(patterns);
+}
+
void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorTransferDropUnitDimsPatterns(patterns);
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index d050885ed30fc1..3e62a8fbf718f9 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -62,3 +62,29 @@ transform.sequence failures(propagate) {
transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d"
} : !transform.any_op
}
+
+// -----
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @fold_arith_extf_into_contract
+// CHECK-SAME: (%[[ARG0:.*]]: vector<64x64xf16>, %[[ARG1:.*]]: vector<64x64xf16>, %[[ARG2:.*]]: vector<64x64xf32>)
+// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<64x64xf16>, vector<64x64xf16> into vector<64x64xf32>
+// CHECK-NEXT: return %[[R]] : vector<64x64xf32>
+// Both contraction operands are produced by f16 -> f32 arith.extf ops; after
+// the patterns run, the contraction should consume the f16 values directly.
+func.func @fold_arith_extf_into_contract(%arg0: vector<64x64xf16>, %arg1: vector<64x64xf16>, %arg2: vector<64x64xf32>) -> vector<64x64xf32> {
+ %lhs_f32 = arith.extf %arg0 : vector<64x64xf16> to vector<64x64xf32>
+ %rhs_f32 = arith.extf %arg1 : vector<64x64xf16> to vector<64x64xf32>
+ %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_f32, %rhs_f32, %arg2 : vector<64x64xf32>, vector<64x64xf32> into vector<64x64xf32>
+ return %result : vector<64x64xf32>
+}
+
+// Match every func.func in the payload module and apply the
+// fold_arith_extension pattern set to it.
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+ %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.fold_arith_extension
+ } : !transform.any_op
+}
More information about the Mlir-commits mailing list