[Mlir-commits] [mlir] ac1e22f - [mlir][vector] Generalize folding of ext-contractionOp to other types. (#96593)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 25 09:29:45 PDT 2024
Author: Stanley Winata
Date: 2024-06-25T09:29:43-07:00
New Revision: ac1e22f3053f761e4e2ef832b92de15876e68335
URL: https://github.com/llvm/llvm-project/commit/ac1e22f3053f761e4e2ef832b92de15876e68335
DIFF: https://github.com/llvm/llvm-project/commit/ac1e22f3053f761e4e2ef832b92de15876e68335.diff
LOG: [mlir][vector] Generalize folding of ext-contractionOp to other types. (#96593)
Many state-of-the-art models and quantization operations now
work directly with vector.contract on integers.
This commit generalizes ext-contraction folding (arith.extf, and now arith.extsi) so that
we can emit more performant vector.contract ops in codegen pipelines.
Signed-off-by: Stanley Winata <stanley.winata at amd.com>
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index eac6db585aad7..da3d9648cf283 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1552,6 +1552,7 @@ struct CanonicalizeContractMatmulToMMT final
/// Cores, i.e, `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
/// This pattern folds the arithmetic extensions into the vector contraction and
/// enables the usage of native mixed precision Tensor Core instructions.
+template <typename ExtOp>
struct FoldArithExtIntoContractionOp
: public OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern::OpRewritePattern;
@@ -1559,8 +1560,8 @@ struct FoldArithExtIntoContractionOp
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
- auto lhsDefOp = contractOp.getLhs().getDefiningOp<arith::ExtFOp>();
- auto rhsDefOp = contractOp.getRhs().getDefiningOp<arith::ExtFOp>();
+ auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
+ auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();
if (!lhsDefOp || !rhsDefOp) {
return rewriter.notifyMatchFailure(contractOp,
@@ -1895,7 +1896,9 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
void mlir::vector::populateFoldArithExtensionPatterns(
RewritePatternSet &patterns) {
- patterns.add<FoldArithExtIntoContractionOp>(patterns.getContext());
+ patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
+ FoldArithExtIntoContractionOp<arith::ExtSIOp>>(
+ patterns.getContext());
}
void mlir::vector::populateVectorMaskMaterializationPatterns(
diff --git a/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir b/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
index 31ae126906f21..6dbde7afbdd33 100644
--- a/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
+++ b/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
@@ -48,3 +48,25 @@ func.func @fold_arith_extf_into_contract_scalable(
%lhs_f32, %rhs_f32, %arg2 : vector<[64]x64xf32>, vector<64x64xf32> into vector<[64]x64xf32>
return %result : vector<[64]x64xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_arith_extsi_into_contract
+// CHECK-SAME: (%[[ARG0:.*]]: vector<64x64xi8>, %[[ARG1:.*]]: vector<64x64xi8>, %[[ARG2:.*]]: vector<64x64xi32>)
+// CHECK-NEXT: %[[R:.+]] = vector.contract
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<64x64xi8>, vector<64x64xi8> into vector<64x64xi32>
+// CHECK-NEXT: return %[[R]] : vector<64x64xi32>
+func.func @fold_arith_extsi_into_contract(
+ %arg0: vector<64x64xi8>,
+ %arg1: vector<64x64xi8>,
+ %arg2: vector<64x64xi32>) -> vector<64x64xi32> {
+ %lhs_i32 = arith.extsi %arg0 : vector<64x64xi8> to vector<64x64xi32>
+ %rhs_i32 = arith.extsi %arg1 : vector<64x64xi8> to vector<64x64xi32>
+ %result = vector.contract {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %lhs_i32, %rhs_i32, %arg2 : vector<64x64xi32>, vector<64x64xi32> into vector<64x64xi32>
+ return %result : vector<64x64xi32>
+}
More information about the Mlir-commits
mailing list