[Mlir-commits] [mlir] [mlir][vector] Generalize folding of ext-contractionOp to other types. (PR #96593)

Stanley Winata llvmlistbot at llvm.org
Mon Jun 24 22:01:32 PDT 2024


https://github.com/raikonenfnu created https://github.com/llvm/llvm-project/pull/96593

Many state-of-the-art models and quantization operations now work directly on vector.contract with integer operands.

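This commit generalizes the folding of arithmetic extensions into vector.contract: in addition to arith.extf, arith.extsi operands are now folded as well, so that codegen pipelines can emit more performant vector.contract ops.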

>From 5da1ee0b162efa892939a5e3469fc3b289a096af Mon Sep 17 00:00:00 2001
From: Stanley Winata <stanley.winata at amd.com>
Date: Mon, 24 Jun 2024 18:00:41 -0700
Subject: [PATCH] [mlir][vector] Generalize folding of ext-contractionOp to
 other types.

Many state-of-the-art models and quantization operations now work
directly on vector.contract with integer operands.

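This commit generalizes the folding of arithmetic extensions into
vector.contract: in addition to arith.extf, arith.extsi operands are
now folded as well, so that codegen pipelines can emit more
performant vector.contract ops.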

Signed-off-by: Stanley Winata <stanley.winata at amd.com>
---
 .../Vector/Transforms/VectorTransforms.cpp    |  9 +++++---
 .../fold-arith-extf-into-vector-contract.mlir | 22 +++++++++++++++++++
 2 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index ea4a02f2f2e77..6dc0e1c1b4bd8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1542,6 +1542,7 @@ struct CanonicalizeContractMatmulToMMT final
 /// Cores, i.e, `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
 /// This pattern folds the arithmetic extensions into the vector contraction and
 /// enables the usage of native mixed precision Tensor Core instructions.
+template <typename ExtOp>
 struct FoldArithExtIntoContractionOp
     : public OpRewritePattern<vector::ContractionOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -1549,8 +1550,8 @@ struct FoldArithExtIntoContractionOp
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 PatternRewriter &rewriter) const override {
 
-    auto lhsDefOp = contractOp.getLhs().getDefiningOp<arith::ExtFOp>();
-    auto rhsDefOp = contractOp.getRhs().getDefiningOp<arith::ExtFOp>();
+    auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
+    auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();
 
     if (!lhsDefOp || !rhsDefOp) {
       return rewriter.notifyMatchFailure(contractOp,
@@ -1804,7 +1805,9 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
 
 void mlir::vector::populateFoldArithExtensionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<FoldArithExtIntoContractionOp>(patterns.getContext());
+  patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
+               FoldArithExtIntoContractionOp<arith::ExtSIOp>>(
+      patterns.getContext());
 }
 
 void mlir::vector::populateVectorMaskMaterializationPatterns(
diff --git a/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir b/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
index 31ae126906f21..6dbde7afbdd33 100644
--- a/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
+++ b/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
@@ -48,3 +48,25 @@ func.func @fold_arith_extf_into_contract_scalable(
       %lhs_f32, %rhs_f32, %arg2 : vector<[64]x64xf32>, vector<64x64xf32> into vector<[64]x64xf32>
     return %result : vector<[64]x64xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @fold_arith_extsi_into_contract
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<64x64xi8>, %[[ARG1:.*]]: vector<64x64xi8>, %[[ARG2:.*]]: vector<64x64xi32>)
+//  CHECK-NEXT:   %[[R:.+]] = vector.contract
+//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+//  CHECK-SAME:   %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<64x64xi8>, vector<64x64xi8> into vector<64x64xi32>
+//  CHECK-NEXT:   return %[[R]] : vector<64x64xi32>
+func.func @fold_arith_extsi_into_contract(
+  %arg0: vector<64x64xi8>,
+  %arg1: vector<64x64xi8>,
+  %arg2: vector<64x64xi32>) -> vector<64x64xi32> {
+    %lhs_i32 = arith.extsi %arg0 : vector<64x64xi8> to vector<64x64xi32>
+    %rhs_i32 = arith.extsi %arg1 : vector<64x64xi8> to vector<64x64xi32>
+    %result = vector.contract {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel", "reduction"],
+      kind = #vector.kind<add>}
+      %lhs_i32, %rhs_i32, %arg2 : vector<64x64xi32>, vector<64x64xi32> into vector<64x64xi32>
+    return %result : vector<64x64xi32>
+}



More information about the Mlir-commits mailing list