[Mlir-commits] [mlir] [mlir][Vector] Support mixed mode vector.contract lowering (PR #117753)

Kunwar Grover llvmlistbot at llvm.org
Tue Dec 3 07:28:17 PST 2024


https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/117753

>From ec96f24273cc5c6f9649c46f52325edd2d84a7e8 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 26 Nov 2024 17:33:59 +0000
Subject: [PATCH 1/2] [mlir][Vector] Support mixed mode vector.contract
 lowering

---
 .../Vector/Transforms/LowerVectorContract.cpp | 57 +++++++++++--------
 .../vector-contract-to-dot-transforms.mlir    | 27 +++++++++
 ...contract-to-parallel-arith-transforms.mlir | 18 ++++++
 3 files changed, 77 insertions(+), 25 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 21261478f0648f..c8ad2892384995 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -80,6 +80,22 @@ static AffineMap adjustMap(AffineMap map, int64_t index,
   return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
 }
 
+Value promoteToElementType(Location loc, RewriterBase &rewriter, Value v,
+                           Type dstElementType) {
+  Type elementType = v.getType();
+  auto vecType = dyn_cast<VectorType>(elementType);
+  if (vecType)
+    elementType = vecType.getElementType();
+  if (elementType == dstElementType)
+    return v;
+  Type promotedType = dstElementType;
+  if (vecType)
+    promotedType = vecType.clone(promotedType);
+  if (isa<FloatType>(dstElementType))
+    return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
+  return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
+}
+
 // Helper method to possibly drop a dimension in a load.
 // TODO
 static Value reshapeLoad(Location loc, Value val, VectorType type,
@@ -136,6 +152,11 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
   using vector::CombiningKind;
   Value mul;
 
+  if (acc) {
+    x = promoteToElementType(loc, rewriter, x, getElementTypeOrSelf(acc));
+    y = promoteToElementType(loc, rewriter, y, getElementTypeOrSelf(acc));
+  }
+
   if (isInt) {
     if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
         kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
@@ -413,21 +434,6 @@ struct UnrolledOuterProductGenerator
     return rewriter.create<vector::TransposeOp>(loc, v, perm);
   }
 
-  Value promote(Value v, Type dstElementType) {
-    Type elementType = v.getType();
-    auto vecType = dyn_cast<VectorType>(elementType);
-    if (vecType)
-      elementType = vecType.getElementType();
-    if (elementType == dstElementType)
-      return v;
-    Type promotedType = dstElementType;
-    if (vecType)
-      promotedType = vecType.clone(promotedType);
-    if (isa<FloatType>(dstElementType))
-      return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
-    return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
-  }
-
   FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
                              VectorType lhsType, int reductionSize,
                              std::optional<Value> maybeMask = std::nullopt) {
@@ -439,8 +445,8 @@ struct UnrolledOuterProductGenerator
     for (int64_t k = 0; k < reductionSize; ++k) {
       Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
       Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
-      extractA = promote(extractA, resElementType);
-      extractB = promote(extractB, resElementType);
+      extractA = promoteToElementType(loc, rewriter, extractA, resElementType);
+      extractB = promoteToElementType(loc, rewriter, extractB, resElementType);
       Value extractMask;
       if (maybeMask.has_value() && maybeMask.value())
         extractMask =
@@ -764,6 +770,8 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
       Value b = rank == 1
                     ? rhs
                     : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
+      a = promoteToElementType(loc, rewriter, a, getElementTypeOrSelf(dstType));
+      b = promoteToElementType(loc, rewriter, b, getElementTypeOrSelf(dstType));
       Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
       Value reduced = rewriter.create<vector::ReductionOp>(
           op.getLoc(), vector::CombiningKind::ADD, m);
@@ -925,12 +933,6 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
   if (failed(filter(op)))
     return failure();
 
-  // TODO: support mixed mode contract lowering.
-  if (op.getLhsType().getElementType() !=
-          getElementTypeOrSelf(op.getAccType()) ||
-      op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
-    return failure();
-
   // TODO: the code below assumes the default contraction, make sure it supports
   // other kinds before enabling this lowering.
   if (op.getKind() != vector::CombiningKind::ADD) {
@@ -1149,10 +1151,15 @@ FailureOr<Value> ContractionOpLowering::lowerReduction(
     if (rhsType.getRank() != 1)
       return rewriter.notifyMatchFailure(
           op, "When LHS has rank 1, expected also RHS to have rank 1");
-    Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
-    auto kind = vector::CombiningKind::ADD;
 
     Value acc = op.getAcc();
+    Value lhs = promoteToElementType(loc, rewriter, op.getLhs(),
+                                     getElementTypeOrSelf(acc));
+    Value rhs = promoteToElementType(loc, rewriter, op.getRhs(),
+                                     getElementTypeOrSelf(acc));
+    Value m = createMul(loc, lhs, rhs, isInt, rewriter);
+    auto kind = vector::CombiningKind::ADD;
+
     Operation *reductionOp =
         acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
             : rewriter.create<vector::ReductionOp>(loc, kind, m);
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
index 0ba185bb847609..3927058a4c6b45 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
@@ -295,6 +295,33 @@ func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1
   return %res : vector<2xi32>
 }
 
+// CHECK-LABEL: @matmul_mixed
+// CHECK:  %[[EXT00:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK:  %[[EXT01:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK:  %[[MUL1:.+]] = arith.mulf %[[EXT00]], %[[EXT01]] : vector<2xf32>
+// CHECK:  vector.reduction <add>, %[[MUL1]] : vector<2xf32> into f32
+
+// CHECK:  %[[EXT11:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK:  %[[MUL2:.+]] = arith.mulf %[[EXT00]], %[[EXT11]] : vector<2xf32>
+// CHECK:  vector.reduction <add>, %[[MUL2]] : vector<2xf32> into f32
+
+// CHECK:  %[[EXT20:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK:  %[[EXT21:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK:  %[[MUL3:.+]] = arith.mulf %[[EXT20]], %[[EXT21]] : vector<2xf32>
+// CHECK:  vector.reduction <add>, %[[MUL3]] : vector<2xf32> into f32
+
+// CHECK:  %[[EXT31:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
+// CHECK:  %[[MUL4:.+]] = arith.mulf %[[EXT20]], %[[EXT31]] : vector<2xf32>
+// CHECK:  vector.reduction <add>, %[[MUL4]] : vector<2xf32> into f32
+
+func.func @matmul_mixed(%arg0: vector<2x2xf16>,
+                        %arg1: vector<2x2xf16>,
+                        %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+    : vector<2x2xf16>, vector<2x2xf16> into vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
index e93c5a08bdc7c9..5d9977e94b1598 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
@@ -51,6 +51,24 @@ func.func @parallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vect
   return %0 : f32
 }
 
+// CHECK-LABEL: func @parallel_contract_lowering_mixed_types
+//       CHECK:   %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16>
+//       CHECK:   %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16>
+//       CHECK:   %[[EXT0:.+]] = arith.extf %[[E0]] : f16 to f32
+//       CHECK:   %[[EXT1:.+]] = arith.extf %[[E1]] : f16 to f32
+//       CHECK:   %[[M:.*]] = arith.mulf %[[EXT0]], %[[EXT1]] : f32
+//       CHECK:   %[[A:.*]] = arith.addf %[[M]], %{{.*}} : f32
+//       CHECK:   return %[[A]] : f32
+func.func @parallel_contract_lowering_mixed_types(%arg0: vector<1x1xf16>, %arg1: vector<1x1xf16>, %arg2: f32) -> f32 {
+  %0 = vector.contract {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> ()>],
+    iterator_types = ["reduction", "reduction"], kind = #vector.kind<add>}
+  %arg0, %arg1, %arg2 : vector<1x1xf16>, vector<1x1xf16> into f32
+  return %0 : f32
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op

>From 466676c81668a80371d5d78669040ecad2c0ca36 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 3 Dec 2024 15:11:55 +0000
Subject: [PATCH 2/2] Address comments

---
 .../Vector/Transforms/LowerVectorContract.cpp | 20 ++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index c8ad2892384995..52117cf0da0417 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -80,19 +80,25 @@ static AffineMap adjustMap(AffineMap map, int64_t index,
   return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
 }
 
-Value promoteToElementType(Location loc, RewriterBase &rewriter, Value v,
-                           Type dstElementType) {
-  Type elementType = v.getType();
-  auto vecType = dyn_cast<VectorType>(elementType);
-  if (vecType)
-    elementType = vecType.getElementType();
+static Value promoteToElementType(Location loc, RewriterBase &rewriter, Value v,
+                                  Type dstElementType) {
+  Type elementType = getElementTypeOrSelf(v.getType());
   if (elementType == dstElementType)
     return v;
+
+  // Operands may only be widened; arith.ext* requires a strictly wider type.
+  assert(elementType.getIntOrFloatBitWidth() <
+             dstElementType.getIntOrFloatBitWidth() &&
+         "vector.contract only allows strict widening of operands");
+
   Type promotedType = dstElementType;
-  if (vecType)
+  if (auto vecType = dyn_cast<VectorType>(v.getType()))
     promotedType = vecType.clone(promotedType);
+
   if (isa<FloatType>(dstElementType))
     return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
+  // For integer types, vector.contract only supports signless integer types
+  // and promotion happens via sign extension.
   return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
 }
 



More information about the Mlir-commits mailing list