[Mlir-commits] [mlir] [mlir][Vector] Support mixed mode vector.contract lowering (PR #117753)
Kunwar Grover
llvmlistbot at llvm.org
Tue Nov 26 09:36:50 PST 2024
https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/117753
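This patch adds lowering support for mixed-mode vector.contract ops, i.e. contractions whose LHS/RHS element types differ from the accumulator/result element type. The implementation follows the vector.contract documentation: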
https://mlir.llvm.org/docs/Dialects/Vector/#vectorcontract-vectorcontractionop
> If operands and the result have types of different bitwidths, operands are promoted to have the same bitwidth as the result before performing the contraction. For integer types, only signless integer types are supported, and the promotion happens via sign extension.
From cee8bec63feb521b6e570e30b571a68e03d390d8 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 26 Nov 2024 17:33:59 +0000
Subject: [PATCH] [mlir][Vector] Support mixed mode vector.contract lowering
---
.../Vector/Transforms/LowerVectorContract.cpp | 57 +++++++++++--------
.../vector-contract-to-dot-transforms.mlir | 27 +++++++++
...contract-to-parallel-arith-transforms.mlir | 18 ++++++
3 files changed, 77 insertions(+), 25 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 21261478f0648f..c8ad2892384995 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -80,6 +80,22 @@ static AffineMap adjustMap(AffineMap map, int64_t index,
return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}
+/// Extends `v` (scalar or vector) to `dstElementType` per vector.contract:
+/// arith.extf for floats, arith.extsi for signless ints (dst assumed wider).
+/// Returns `v` unchanged when its element type already matches.
+static Value promoteToElementType(Location loc, RewriterBase &rewriter,
+                                  Value v, Type dstElementType) {
+  Type elementType = getElementTypeOrSelf(v.getType());
+  if (elementType == dstElementType)
+    return v;
+  Type promotedType = dstElementType;
+  if (auto vecType = dyn_cast<VectorType>(v.getType()))
+    promotedType = vecType.clone(promotedType);
+  if (isa<FloatType>(dstElementType))
+    return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
+  return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
+}
+
// Helper method to possibly drop a dimension in a load.
// TODO
static Value reshapeLoad(Location loc, Value val, VectorType type,
@@ -136,6 +152,11 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
using vector::CombiningKind;
Value mul;
+ if (acc) {
+ x = promoteToElementType(loc, rewriter, x, getElementTypeOrSelf(acc));
+ y = promoteToElementType(loc, rewriter, y, getElementTypeOrSelf(acc));
+ }
+
if (isInt) {
if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
@@ -413,21 +434,6 @@ struct UnrolledOuterProductGenerator
return rewriter.create<vector::TransposeOp>(loc, v, perm);
}
- Value promote(Value v, Type dstElementType) {
- Type elementType = v.getType();
- auto vecType = dyn_cast<VectorType>(elementType);
- if (vecType)
- elementType = vecType.getElementType();
- if (elementType == dstElementType)
- return v;
- Type promotedType = dstElementType;
- if (vecType)
- promotedType = vecType.clone(promotedType);
- if (isa<FloatType>(dstElementType))
- return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
- return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
- }
-
FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
VectorType lhsType, int reductionSize,
std::optional<Value> maybeMask = std::nullopt) {
@@ -439,8 +445,8 @@ struct UnrolledOuterProductGenerator
for (int64_t k = 0; k < reductionSize; ++k) {
Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
- extractA = promote(extractA, resElementType);
- extractB = promote(extractB, resElementType);
+ extractA = promoteToElementType(loc, rewriter, extractA, resElementType);
+ extractB = promoteToElementType(loc, rewriter, extractB, resElementType);
Value extractMask;
if (maybeMask.has_value() && maybeMask.value())
extractMask =
@@ -764,6 +770,8 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
Value b = rank == 1
? rhs
: rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
+ a = promoteToElementType(loc, rewriter, a, getElementTypeOrSelf(dstType));
+ b = promoteToElementType(loc, rewriter, b, getElementTypeOrSelf(dstType));
Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
Value reduced = rewriter.create<vector::ReductionOp>(
op.getLoc(), vector::CombiningKind::ADD, m);
@@ -925,12 +933,6 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
if (failed(filter(op)))
return failure();
- // TODO: support mixed mode contract lowering.
- if (op.getLhsType().getElementType() !=
- getElementTypeOrSelf(op.getAccType()) ||
- op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
- return failure();
-
// TODO: the code below assumes the default contraction, make sure it supports
// other kinds before enabling this lowering.
if (op.getKind() != vector::CombiningKind::ADD) {
@@ -1149,10 +1151,15 @@ FailureOr<Value> ContractionOpLowering::lowerReduction(
if (rhsType.getRank() != 1)
return rewriter.notifyMatchFailure(
op, "When LHS has rank 1, expected also RHS to have rank 1");
- Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
- auto kind = vector::CombiningKind::ADD;
Value acc = op.getAcc();
+ Value lhs = promoteToElementType(loc, rewriter, op.getLhs(),
+ getElementTypeOrSelf(acc));
+ Value rhs = promoteToElementType(loc, rewriter, op.getRhs(),
+ getElementTypeOrSelf(acc));
+ Value m = createMul(loc, lhs, rhs, isInt, rewriter);
+ auto kind = vector::CombiningKind::ADD;
+
Operation *reductionOp =
acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
: rewriter.create<vector::ReductionOp>(loc, kind, m);
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
index 0ba185bb847609..3927058a4c6b45 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
@@ -295,6 +295,33 @@ func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1
return %res : vector<2xi32>
}
+// CHECK-LABEL: @matmul_mixed
+// CHECK: %[[EXT00:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK: %[[EXT01:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK: %[[MUL1:.+]] = arith.mulf %[[EXT00]], %[[EXT01]] : vector<2xf32>
+// CHECK: vector.reduction <add>, %[[MUL1]] : vector<2xf32> into f32
+
+// CHECK: %[[EXT11:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK: %[[MUL2:.+]] = arith.mulf %[[EXT00]], %[[EXT11]] : vector<2xf32>
+// CHECK: vector.reduction <add>, %[[MUL2]] : vector<2xf32> into f32
+
+// CHECK: %[[EXT20:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK: %[[EXT21:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK: %[[MUL3:.+]] = arith.mulf %[[EXT20]], %[[EXT21]] : vector<2xf32>
+// CHECK: vector.reduction <add>, %[[MUL3]] : vector<2xf32> into f32
+
+// CHECK: %[[EXT31:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK: %[[MUL4:.+]] = arith.mulf %[[EXT20]], %[[EXT31]] : vector<2xf32>
+// CHECK: vector.reduction <add>, %[[MUL4]] : vector<2xf32> into f32
+
+func.func @matmul_mixed(%lhs: vector<2x2xf16>,
+                        %rhs: vector<2x2xf16>,
+                        %acc: vector<2x2xf32>) -> vector<2x2xf32> {
+  %res = vector.contract #matmat_trait %lhs, %rhs, %acc
+    : vector<2x2xf16>, vector<2x2xf16> into vector<2x2xf32>
+  return %res : vector<2x2xf32>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
index e93c5a08bdc7c9..5d9977e94b1598 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
@@ -51,6 +51,24 @@ func.func @parallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vect
return %0 : f32
}
+// CHECK-LABEL: func @parallel_contract_lowering_mixed_types
+// CHECK: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16>
+// CHECK: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16>
+// CHECK: %[[EXT0:.+]] = arith.extf %[[E0]] : f16 to f32
+// CHECK: %[[EXT1:.+]] = arith.extf %[[E1]] : f16 to f32
+// CHECK: %[[M:.*]] = arith.mulf %[[EXT0]], %[[EXT1]] : f32
+// CHECK: %[[A:.*]] = arith.addf %[[M]], %{{.*}} : f32
+// CHECK: return %[[A]] : f32
+func.func @parallel_contract_lowering_mixed_types(%lhs: vector<1x1xf16>, %rhs: vector<1x1xf16>, %acc: f32) -> f32 {
+  %res = vector.contract {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> ()>],
+    iterator_types = ["reduction", "reduction"], kind = #vector.kind<add>}
+    %lhs, %rhs, %acc : vector<1x1xf16>, vector<1x1xf16> into f32
+  return %res : f32
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
More information about the Mlir-commits mailing list