[Mlir-commits] [mlir] [mlir][Vector] Support mixed mode vector.contract lowering (PR #117753)
Kunwar Grover
llvmlistbot at llvm.org
Tue Dec 3 07:28:17 PST 2024
https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/117753
>From ec96f24273cc5c6f9649c46f52325edd2d84a7e8 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 26 Nov 2024 17:33:59 +0000
Subject: [PATCH 1/2] [mlir][Vector] Support mixed mode vector.contract lowering
---
.../Vector/Transforms/LowerVectorContract.cpp | 57 +++++++++++--------
.../vector-contract-to-dot-transforms.mlir | 27 +++++++++
...contract-to-parallel-arith-transforms.mlir | 18 ++++++
3 files changed, 77 insertions(+), 25 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 21261478f0648f..c8ad2892384995 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -80,6 +80,22 @@ static AffineMap adjustMap(AffineMap map, int64_t index,
return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}
+Value promoteToElementType(Location loc, RewriterBase &rewriter, Value v,
+ Type dstElementType) {
+ Type elementType = v.getType();
+ auto vecType = dyn_cast<VectorType>(elementType);
+ if (vecType)
+ elementType = vecType.getElementType();
+ if (elementType == dstElementType)
+ return v;
+ Type promotedType = dstElementType;
+ if (vecType)
+ promotedType = vecType.clone(promotedType);
+ if (isa<FloatType>(dstElementType))
+ return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
+ return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
+}
+
// Helper method to possibly drop a dimension in a load.
// TODO
static Value reshapeLoad(Location loc, Value val, VectorType type,
@@ -136,6 +152,11 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
using vector::CombiningKind;
Value mul;
+ if (acc) {
+ x = promoteToElementType(loc, rewriter, x, getElementTypeOrSelf(acc));
+ y = promoteToElementType(loc, rewriter, y, getElementTypeOrSelf(acc));
+ }
+
if (isInt) {
if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
@@ -413,21 +434,6 @@ struct UnrolledOuterProductGenerator
return rewriter.create<vector::TransposeOp>(loc, v, perm);
}
- Value promote(Value v, Type dstElementType) {
- Type elementType = v.getType();
- auto vecType = dyn_cast<VectorType>(elementType);
- if (vecType)
- elementType = vecType.getElementType();
- if (elementType == dstElementType)
- return v;
- Type promotedType = dstElementType;
- if (vecType)
- promotedType = vecType.clone(promotedType);
- if (isa<FloatType>(dstElementType))
- return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
- return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
- }
-
FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
VectorType lhsType, int reductionSize,
std::optional<Value> maybeMask = std::nullopt) {
@@ -439,8 +445,8 @@ struct UnrolledOuterProductGenerator
for (int64_t k = 0; k < reductionSize; ++k) {
Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
- extractA = promote(extractA, resElementType);
- extractB = promote(extractB, resElementType);
+ extractA = promoteToElementType(loc, rewriter, extractA, resElementType);
+ extractB = promoteToElementType(loc, rewriter, extractB, resElementType);
Value extractMask;
if (maybeMask.has_value() && maybeMask.value())
extractMask =
@@ -764,6 +770,8 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
Value b = rank == 1
? rhs
: rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
+ a = promoteToElementType(loc, rewriter, a, getElementTypeOrSelf(dstType));
+ b = promoteToElementType(loc, rewriter, b, getElementTypeOrSelf(dstType));
Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
Value reduced = rewriter.create<vector::ReductionOp>(
op.getLoc(), vector::CombiningKind::ADD, m);
@@ -925,12 +933,6 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
if (failed(filter(op)))
return failure();
- // TODO: support mixed mode contract lowering.
- if (op.getLhsType().getElementType() !=
- getElementTypeOrSelf(op.getAccType()) ||
- op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
- return failure();
-
// TODO: the code below assumes the default contraction, make sure it supports
// other kinds before enabling this lowering.
if (op.getKind() != vector::CombiningKind::ADD) {
@@ -1149,10 +1151,15 @@ FailureOr<Value> ContractionOpLowering::lowerReduction(
if (rhsType.getRank() != 1)
return rewriter.notifyMatchFailure(
op, "When LHS has rank 1, expected also RHS to have rank 1");
- Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
- auto kind = vector::CombiningKind::ADD;
Value acc = op.getAcc();
+ Value lhs = promoteToElementType(loc, rewriter, op.getLhs(),
+ getElementTypeOrSelf(acc));
+ Value rhs = promoteToElementType(loc, rewriter, op.getRhs(),
+ getElementTypeOrSelf(acc));
+ Value m = createMul(loc, lhs, rhs, isInt, rewriter);
+ auto kind = vector::CombiningKind::ADD;
+
Operation *reductionOp =
acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
: rewriter.create<vector::ReductionOp>(loc, kind, m);
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
index 0ba185bb847609..3927058a4c6b45 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
@@ -295,6 +295,33 @@ func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1
return %res : vector<2xi32>
}
+// CHECK-LABEL: @matmul_mixed
+// CHECK: %[[EXT00:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK: %[[EXT01:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK: %[[MUL1:.+]] = arith.mulf %[[EXT00]], %[[EXT01]] : vector<2xf32>
+// CHECK: vector.reduction <add>, %[[MUL1]] : vector<2xf32> into f32
+
+// CHECK: %[[EXT11:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK: %[[MUL2:.+]] = arith.mulf %[[EXT00]], %[[EXT11]] : vector<2xf32>
+// CHECK: vector.reduction <add>, %[[MUL2]] : vector<2xf32> into f32
+
+// CHECK: %[[EXT20:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK: %[[EXT21:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK: %[[MUL3:.+]] = arith.mulf %[[EXT20]], %[[EXT21]] : vector<2xf32>
+// CHECK: vector.reduction <add>, %[[MUL3]] : vector<2xf32> into f32
+
+// CHECK: %[[EXT31:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK: %[[MUL4:.+]] = arith.mulf %[[EXT20]], %[[EXT31]] : vector<2xf32>
+// CHECK: vector.reduction <add>, %[[MUL4]] : vector<2xf32> into f32
+
+func.func @matmul_mixed(%arg0: vector<2x2xf16>,
+ %arg1: vector<2x2xf16>,
+ %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+ %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+ : vector<2x2xf16>, vector<2x2xf16> into vector<2x2xf32>
+ return %0 : vector<2x2xf32>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
index e93c5a08bdc7c9..5d9977e94b1598 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
@@ -51,6 +51,24 @@ func.func @parallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vect
return %0 : f32
}
+// CHECK-LABEL: func @parallel_contract_lowering_mixed_types
+// CHECK: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16>
+// CHECK: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16>
+// CHECK: %[[EXT0:.+]] = arith.extf %[[E0]] : f16 to f32
+// CHECK: %[[EXT1:.+]] = arith.extf %[[E1]] : f16 to f32
+// CHECK: %[[M:.*]] = arith.mulf %[[EXT0]], %[[EXT1]] : f32
+// CHECK: %[[A:.*]] = arith.addf %[[M]], %{{.*}} : f32
+// CHECK: return %[[A]] : f32
+func.func @parallel_contract_lowering_mixed_types(%arg0: vector<1x1xf16>, %arg1: vector<1x1xf16>, %arg2: f32) -> f32 {
+ %0 = vector.contract {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> ()>],
+ iterator_types = ["reduction", "reduction"], kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2 : vector<1x1xf16>, vector<1x1xf16> into f32
+ return %0 : f32
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
>From 466676c81668a80371d5d78669040ecad2c0ca36 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 3 Dec 2024 15:11:55 +0000
Subject: [PATCH 2/2] Address comments
---
.../Vector/Transforms/LowerVectorContract.cpp | 20 ++++++++++++-------
1 file changed, 13 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index c8ad2892384995..52117cf0da0417 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -80,19 +80,25 @@ static AffineMap adjustMap(AffineMap map, int64_t index,
return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}
-Value promoteToElementType(Location loc, RewriterBase &rewriter, Value v,
- Type dstElementType) {
- Type elementType = v.getType();
- auto vecType = dyn_cast<VectorType>(elementType);
- if (vecType)
- elementType = vecType.getElementType();
+static Value promoteToElementType(Location loc, RewriterBase &rewriter, Value v,
+ Type dstElementType) {
+ Type elementType = getElementTypeOrSelf(v.getType());
if (elementType == dstElementType)
return v;
+
+ // vector.contract only allows extension on operands.
+ assert(elementType.getIntOrFloatBitWidth() <=
+ dstElementType.getIntOrFloatBitWidth() &&
+ "vector.contract does not allow truncation of operands");
+
Type promotedType = dstElementType;
- if (vecType)
+ if (auto vecType = dyn_cast<VectorType>(v.getType()))
promotedType = vecType.clone(promotedType);
+
if (isa<FloatType>(dstElementType))
return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
+ // For integer types, vector.contract only supports signless integer types
+ // and promotion happens via sign extension.
return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
}
More information about the Mlir-commits mailing list