[Mlir-commits] [mlir] Extract RHS rows once when lowering vector.contract to dot (PR #130130)

llvmlistbot at llvm.org
Thu Mar 6 08:11:50 PST 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir-vector

Author: Artemiy Bulavin (abulavin)

<details>
<summary>Changes</summary>

The `vector.contract` op on two matrices A and B is lowered to individual dot products, one for each pairing of a row of A with a column of B. The existing lowering extracts every column of B again for each row of A, so the IR ends up with multiple values that represent the same column of B.

This PR changes `ContractOpToDotLowering` so that each column of B is extracted only once, and the SSA values representing the extracted columns are re-used in the IR for later dot products.
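
The effect of the change can be illustrated outside of MLIR with a minimal sketch. This is not the code from the patch: `ExtractCounter`, `lowerNaive`, and `lowerCached` are hypothetical names, plain `std::vector` rows stand in for SSA values, and the counter only models how many extract operations each loop shape would create. The RHS is assumed to be stored so that entry `c` is the vector paired with LHS row `r`, and the rank-1 RHS case is not modelled.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Row = std::vector<float>;
using Matrix = std::vector<Row>;

// Hypothetical stand-in for creating one extract op; `count` models how
// many extract values would appear in the generated IR.
struct ExtractCounter {
  std::size_t count = 0;
  const Row &extract(const Matrix &m, std::size_t i) {
    ++count;
    return m[i];
  }
};

// Models the multiply followed by an add-reduction.
float dot(const Row &a, const Row &b) {
  assert(a.size() == b.size());
  float sum = 0.0f;
  for (std::size_t k = 0; k < a.size(); ++k)
    sum += a[k] * b[k];
  return sum;
}

// Old loop shape: RHS entry c is extracted again for every LHS row.
Matrix lowerNaive(const Matrix &lhs, const Matrix &rhs, ExtractCounter &ctr) {
  Matrix res(lhs.size(), Row(rhs.size(), 0.0f));
  for (std::size_t r = 0; r < lhs.size(); ++r) {
    const Row &a = ctr.extract(lhs, r);
    for (std::size_t c = 0; c < rhs.size(); ++c) {
      const Row &b = ctr.extract(rhs, c);
      res[r][c] = dot(a, b);
    }
  }
  return res;
}

// New loop shape: RHS entries are extracted while r == 0 and re-used after.
Matrix lowerCached(const Matrix &lhs, const Matrix &rhs, ExtractCounter &ctr) {
  Matrix res(lhs.size(), Row(rhs.size(), 0.0f));
  std::vector<const Row *> extracted;
  extracted.reserve(rhs.size());
  for (std::size_t r = 0; r < lhs.size(); ++r) {
    const Row &a = ctr.extract(lhs, r);
    for (std::size_t c = 0; c < rhs.size(); ++c) {
      if (r == 0)
        extracted.push_back(&ctr.extract(rhs, c));
      const Row &b = *extracted[c];
      res[r][c] = dot(a, b);
    }
  }
  return res;
}
```

In this model, a result with R rows and C columns needs R + R*C extracts with the old loop shape and R + C with the new one, while the computed values are identical.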

I have updated the existing vector-contract-to-dot-transforms test.

---
Full diff: https://github.com/llvm/llvm-project/pull/130130.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+11-3) 
- (modified) mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir (+2-4) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 21261478f0648..88227ecb6c675 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -758,12 +758,20 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
   Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                  rewriter.getZeroAttr(dstType));
   bool isInt = isa<IntegerType>(dstType.getElementType());
+  llvm::SmallVector<Value> extractedCols;
+  extractedCols.reserve(dstColumns);
   for (unsigned r = 0; r < dstRows; ++r) {
     Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
     for (unsigned c = 0; c < dstColumns; ++c) {
-      Value b = rank == 1
-                    ? rhs
-                    : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
+      if (r == 0) {
+        // We only need to extract the rows of the RHS once
+        // and then re-use them later.
+        Value b = rank == 1
+                      ? rhs
+                      : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
+        extractedCols.push_back(b);
+      }
+      Value b = extractedCols[c];
       Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
       Value reduced = rewriter.create<vector::ReductionOp>(
           op.getLoc(), vector::CombiningKind::ADD, m);
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
index 0ba185bb84760..f77027c3c2d95 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
@@ -169,13 +169,11 @@ func.func @extract_contract3(%arg0: vector<3xf32>,
 // CHECK:    %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [0, 1] : f32 into vector<2x2xf32>
 //
 // CHECK:    %[[T23:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK:    %[[T24:.*]] = vector.extract %[[Bt]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK:    %[[T32:.*]] = arith.mulf %[[T23]], %[[T24]] : vector<2xf32>
+// CHECK:    %[[T32:.*]] = arith.mulf %[[T23]], %[[T2]] : vector<2xf32>
 // CHECK:    %[[T33:.*]] = vector.reduction <add>, %[[T32]] : vector<2xf32> into f32
 // CHECK:    %[[T34:.*]] = vector.insert %[[T33]], %[[T21]] [1, 0] : f32 into vector<2x2xf32>
 //
-// CHECK:    %[[T40:.*]] = vector.extract %[[Bt]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK:    %[[T41:.*]] = arith.mulf %[[T23]], %[[T40]] : vector<2xf32>
+// CHECK:    %[[T41:.*]] = arith.mulf %[[T23]], %[[T12]] : vector<2xf32>
 // CHECK:    %[[T42:.*]] = vector.reduction <add>, %[[T41]] : vector<2xf32> into f32
 // CHECK:    %[[T43:.*]] = vector.insert %[[T42]], %[[T34]] [1, 1] : f32 into vector<2x2xf32>
 //

``````````

</details>


https://github.com/llvm/llvm-project/pull/130130

