[Mlir-commits] [mlir] [mlir] Extract RHS rows once when lowering vector.contract to dot (PR #130130)
Artemiy Bulavin
llvmlistbot at llvm.org
Tue Mar 11 04:27:35 PDT 2025
https://github.com/abulavin updated https://github.com/llvm/llvm-project/pull/130130
>From 8e4544f35cc83b792e6d5588ef3e2890c613e723 Mon Sep 17 00:00:00 2001
From: Artemiy Bulavin <artemiyb at graphcore.ai>
Date: Wed, 5 Mar 2025 19:25:24 +0000
Subject: [PATCH 1/5] Extract RHS rows once when lowering vector.contract to
dot
---
.../Vector/Transforms/LowerVectorContract.cpp | 14 +++++++++++---
.../Vector/vector-contract-to-dot-transforms.mlir | 6 ++----
2 files changed, 13 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 21261478f0648..88227ecb6c675 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -758,12 +758,20 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
rewriter.getZeroAttr(dstType));
bool isInt = isa<IntegerType>(dstType.getElementType());
+ llvm::SmallVector<Value> extractedCols;
+ extractedCols.reserve(dstColumns);
for (unsigned r = 0; r < dstRows; ++r) {
Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
for (unsigned c = 0; c < dstColumns; ++c) {
- Value b = rank == 1
- ? rhs
- : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
+ if (r == 0) {
+ // We only need to extract the rows of the RHS once
+ // and then re-use them later.
+ Value b = rank == 1
+ ? rhs
+ : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
+ extractedCols.push_back(b);
+ }
+ Value b = extractedCols[c];
Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
Value reduced = rewriter.create<vector::ReductionOp>(
op.getLoc(), vector::CombiningKind::ADD, m);
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
index 0ba185bb84760..f77027c3c2d95 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
@@ -169,13 +169,11 @@ func.func @extract_contract3(%arg0: vector<3xf32>,
// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [0, 1] : f32 into vector<2x2xf32>
//
// CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T24:.*]] = vector.extract %[[Bt]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T32:.*]] = arith.mulf %[[T23]], %[[T24]] : vector<2xf32>
+// CHECK: %[[T32:.*]] = arith.mulf %[[T23]], %[[T2]] : vector<2xf32>
// CHECK: %[[T33:.*]] = vector.reduction <add>, %[[T32]] : vector<2xf32> into f32
// CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[T21]] [1, 0] : f32 into vector<2x2xf32>
//
-// CHECK: %[[T40:.*]] = vector.extract %[[Bt]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T41:.*]] = arith.mulf %[[T23]], %[[T40]] : vector<2xf32>
+// CHECK: %[[T41:.*]] = arith.mulf %[[T23]], %[[T12]] : vector<2xf32>
// CHECK: %[[T42:.*]] = vector.reduction <add>, %[[T41]] : vector<2xf32> into f32
// CHECK: %[[T43:.*]] = vector.insert %[[T42]], %[[T34]] [1, 1] : f32 into vector<2x2xf32>
//
>From ce9c7201a97d6e4117862c540b587011afe1d620 Mon Sep 17 00:00:00 2001
From: Artemiy Bulavin <artemiyb at graphcore.ai>
Date: Thu, 6 Mar 2025 20:07:06 +0000
Subject: [PATCH 2/5] fixup: Refactor variable names in vector contract to dot
transform test
---
.../vector-contract-to-dot-transforms.mlir | 73 +++++++++++--------
1 file changed, 44 insertions(+), 29 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
index f77027c3c2d95..739796099f795 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
@@ -151,41 +151,56 @@ func.func @extract_contract3(%arg0: vector<3xf32>,
iterator_types = ["parallel", "parallel", "reduction"]
}
-// CHECK-LABEL: func @extract_contract4
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>,
-// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
-// CHECK-SAME: %[[C:.*2]]: vector<2x2xf32>
-// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
-// CHECK: %[[Bt:.*]] = vector.transpose %arg1, [1, 0] : vector<2x2xf32> to vector<2x2xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T9:.*]] = arith.mulf %[[T0]], %[[T2]] : vector<2xf32>
-// CHECK: %[[T10:.*]] = vector.reduction <add>, %[[T9]] : vector<2xf32> into f32
-// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[R]] [0, 0] : f32 into vector<2x2xf32>
+// CHECK-LABEL: func @contract_to_dot_matmat
+// CHECK-SAME: %[[LHS:.*0]]: vector<2x2xf32>,
+// CHECK-SAME: %[[RHS:.*1]]: vector<2x2xf32>,
+// CHECK-SAME: %[[OUT:.*2]]: vector<2x2xf32>
//
-// CHECK: %[[T12:.*]] = vector.extract %[[Bt]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T19:.*]] = arith.mulf %[[T0]], %[[T12]] : vector<2xf32>
-// CHECK: %[[T20:.*]] = vector.reduction <add>, %[[T19]] : vector<2xf32> into f32
-// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [0, 1] : f32 into vector<2x2xf32>
+// The `vector.contract` to dot lowering will 'unroll' a matrix-matrix
+// multiplication into individual dot products between rows of the LHS and columns
+// of the RHS. In the following test we expect 4 extract-dotproduct-insert sequences of
+// ops that correspond to the 4 dot products resulting from unrolling a matmul between
+// two matrices of size (2, 2).
//
-// CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T32:.*]] = arith.mulf %[[T23]], %[[T2]] : vector<2xf32>
-// CHECK: %[[T33:.*]] = vector.reduction <add>, %[[T32]] : vector<2xf32> into f32
-// CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[T21]] [1, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
//
-// CHECK: %[[T41:.*]] = arith.mulf %[[T23]], %[[T12]] : vector<2xf32>
-// CHECK: %[[T42:.*]] = vector.reduction <add>, %[[T41]] : vector<2xf32> into f32
-// CHECK: %[[T43:.*]] = vector.insert %[[T42]], %[[T34]] [1, 1] : f32 into vector<2x2xf32>
+// First, the RHS will be transposed to make it easier to extract individual columns
+// using vector.extract.
//
-// CHECK: %[[T52:.*]] = arith.addf %[[T43]], %[[C]] : vector<2x2xf32>
-// CHECK: return %[[T52]] : vector<2x2xf32>
+// CHECK: %[[RHS_T:.*]] = vector.transpose %[[RHS]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+//
+// Next, we expect 4 sequences that take a row of the LHS and a row of the transposed
+// RHS (each row is extracted only once and then reused), perform a dot product, and
+// insert the result into the output vector.
+//
+// CHECK: %[[LHS0:.*]] = vector.extract %[[LHS]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[RHS_T0:.*]] = vector.extract %[[RHS_T]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[PROD0:.*]] = arith.mulf %[[LHS0]], %[[RHS_T0]] : vector<2xf32>
+// CHECK: %[[SUM0:.*]] = vector.reduction <add>, %[[PROD0]] : vector<2xf32> into f32
+// CHECK: %[[RES0:.*]] = vector.insert %[[SUM0]], %[[INIT]] [0, 0] : f32 into vector<2x2xf32>
+//
+// CHECK: %[[RHS_T1:.*]] = vector.extract %[[RHS_T]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[PROD1:.*]] = arith.mulf %[[LHS0]], %[[RHS_T1]] : vector<2xf32>
+// CHECK: %[[SUM1:.*]] = vector.reduction <add>, %[[PROD1]] : vector<2xf32> into f32
+// CHECK: %[[RES1:.*]] = vector.insert %[[SUM1]], %[[RES0]] [0, 1] : f32 into vector<2x2xf32>
+//
+// CHECK: %[[LHS1:.*]] = vector.extract %[[LHS]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[PROD2:.*]] = arith.mulf %[[LHS1]], %[[RHS_T0]] : vector<2xf32>
+// CHECK: %[[SUM2:.*]] = vector.reduction <add>, %[[PROD2]] : vector<2xf32> into f32
+// CHECK: %[[RES2:.*]] = vector.insert %[[SUM2]], %[[RES1]] [1, 0] : f32 into vector<2x2xf32>
+//
+// CHECK: %[[PROD3:.*]] = arith.mulf %[[LHS1]], %[[RHS_T1]] : vector<2xf32>
+// CHECK: %[[SUM3:.*]] = vector.reduction <add>, %[[PROD3]] : vector<2xf32> into f32
+// CHECK: %[[RES3:.*]] = vector.insert %[[SUM3]], %[[RES2]] [1, 1] : f32 into vector<2x2xf32>
+//
+// CHECK: %[[RES:.*]] = arith.addf %[[RES3]], %[[OUT]] : vector<2x2xf32>
+// CHECK: return %[[RES]] : vector<2x2xf32>
-func.func @extract_contract4(%arg0: vector<2x2xf32>,
- %arg1: vector<2x2xf32>,
- %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
- %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+func.func @contract_to_dot_matmat(%lhs: vector<2x2xf32>,
+ %rhs: vector<2x2xf32>,
+ %init: vector<2x2xf32>) -> vector<2x2xf32> {
+ %res = vector.contract #matmat_trait %lhs, %rhs, %init
: vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
- return %0 : vector<2x2xf32>
+ return %res : vector<2x2xf32>
}
>From 4018f2367d28fada9303af424e4aeb4aa00f49fd Mon Sep 17 00:00:00 2001
From: Artemiy Bulavin <artemiyb at graphcore.ai>
Date: Thu, 6 Mar 2025 20:07:34 +0000
Subject: [PATCH 3/5] fixup: refer to columns instead of rows in reuse comment
---
mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 88227ecb6c675..cdbf8644418a3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -764,7 +764,7 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
for (unsigned c = 0; c < dstColumns; ++c) {
if (r == 0) {
- // We only need to extract the rows of the RHS once
+ // We only need to extract the columns of the RHS once
// and then re-use them later.
Value b = rank == 1
? rhs
>From 4c74d73ac63d2ff6c30e9e1c6a3e59b77bd4a774 Mon Sep 17 00:00:00 2001
From: Artemiy Bulavin <artemiyb at graphcore.ai>
Date: Tue, 11 Mar 2025 11:21:47 +0000
Subject: [PATCH 4/5] fixup: Rephrase comment and rename column/row variables
---
.../Vector/Transforms/LowerVectorContract.cpp | 23 ++++++++++---------
1 file changed, 12 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index cdbf8644418a3..6bb3d56937129 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -761,24 +761,25 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
llvm::SmallVector<Value> extractedCols;
extractedCols.reserve(dstColumns);
for (unsigned r = 0; r < dstRows; ++r) {
- Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
+ Value rowLhs = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
for (unsigned c = 0; c < dstColumns; ++c) {
+ // Extract each respective row and column of the LHS and RHS once to
+ // avoid having duplicate SSA values pointing to the same rows/columns.
if (r == 0) {
- // We only need to extract the columns of the RHS once
- // and then re-use them later.
- Value b = rank == 1
- ? rhs
+ Value colRhs =
+ rank == 1 ? rhs
: rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
- extractedCols.push_back(b);
+ extractedCols.push_back(colRhs);
}
- Value b = extractedCols[c];
- Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
- Value reduced = rewriter.create<vector::ReductionOp>(
- op.getLoc(), vector::CombiningKind::ADD, m);
+ Value extractedColRhs = extractedCols[c];
+ Value product =
+ createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter);
+ Value sum = rewriter.create<vector::ReductionOp>(
+ op.getLoc(), vector::CombiningKind::ADD, product);
SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
: SmallVector<int64_t, 2>{r, c};
- res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
+ res = rewriter.create<vector::InsertOp>(op.getLoc(), sum, res, pos);
}
}
if (auto acc = op.getAcc())
>From 86853bf2df45fda7c114ca01311eb37f7cd30dc2 Mon Sep 17 00:00:00 2001
From: Artemiy Bulavin <artemiyb at graphcore.ai>
Date: Tue, 11 Mar 2025 11:27:13 +0000
Subject: [PATCH 5/5] fixup: remove trailing whitespace
---
mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 6bb3d56937129..902adae6feeb1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -764,7 +764,7 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
Value rowLhs = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
for (unsigned c = 0; c < dstColumns; ++c) {
// Extract each respective row and column of the LHS and RHS once to
- // avoid having duplicate SSA values pointing to the same rows/columns.
+ // avoid having duplicate SSA values pointing to the same rows/columns.
if (r == 0) {
Value colRhs =
rank == 1 ? rhs
More information about the Mlir-commits
mailing list