[Mlir-commits] [mlir] fc127ff - [mlir] Extract RHS rows once when lowering vector.contract to dot (#130130)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 12 10:16:53 PDT 2025


Author: Artemiy Bulavin
Date: 2025-03-12T17:16:49Z
New Revision: fc127ff53d0c816e9e9a64ef55868479e0b84ebd

URL: https://github.com/llvm/llvm-project/commit/fc127ff53d0c816e9e9a64ef55868479e0b84ebd
DIFF: https://github.com/llvm/llvm-project/commit/fc127ff53d0c816e9e9a64ef55868479e0b84ebd.diff

LOG: [mlir] Extract RHS rows once when lowering vector.contract to dot (#130130)

The `vector.contract` op on two matrices A and B will be lowered to
individual dot products of each row and column of A and B respectively.
The existing lowering will extract each column of B for each row of A,
which leads to multiple values in the IR representing the same columns
of B.

This PR makes changes to the `ContractionOpToDotLowering` to make sure that
the columns of B are only ever extracted once, so that the SSA values
representing the extracted columns are re-used in the IR for later
dot products.

I have updated the existing vector-contract-to-dot-transforms test.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
    mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index c74d0622b3828..c6627b5ec0d77 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -757,19 +757,28 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
   Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                  rewriter.getZeroAttr(dstType));
   bool isInt = isa<IntegerType>(dstType.getElementType());
+  llvm::SmallVector<Value> extractedCols;
+  extractedCols.reserve(dstColumns);
   for (unsigned r = 0; r < dstRows; ++r) {
-    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
+    Value rowLhs = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
     for (unsigned c = 0; c < dstColumns; ++c) {
-      Value b = rank == 1
-                    ? rhs
-                    : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
-      Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
-      Value reduced = rewriter.create<vector::ReductionOp>(
-          op.getLoc(), vector::CombiningKind::ADD, m);
+      // Extract each respective row and column of the LHS and RHS once to
+      // avoid having duplicate SSA values pointing to the same rows/columns.
+      if (r == 0) {
+        Value colRhs =
+            rank == 1 ? rhs
+                      : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
+        extractedCols.push_back(colRhs);
+      }
+      Value extractedColRhs = extractedCols[c];
+      Value product =
+          createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter);
+      Value sum = rewriter.create<vector::ReductionOp>(
+          op.getLoc(), vector::CombiningKind::ADD, product);
 
       SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                               : SmallVector<int64_t, 2>{r, c};
-      res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
+      res = rewriter.create<vector::InsertOp>(op.getLoc(), sum, res, pos);
     }
   }
   if (auto acc = op.getAcc())

diff  --git a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
index 0ba185bb84760..739796099f795 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
@@ -151,43 +151,56 @@ func.func @extract_contract3(%arg0: vector<3xf32>,
   iterator_types = ["parallel", "parallel", "reduction"]
 }
 
-// CHECK-LABEL: func @extract_contract4
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>,
-// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
-// CHECK-SAME: %[[C:.*2]]: vector<2x2xf32>
-// CHECK:    %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
-// CHECK:    %[[Bt:.*]] = vector.transpose %arg1, [1, 0] : vector<2x2xf32> to vector<2x2xf32>
-// CHECK:    %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK:    %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK:    %[[T9:.*]] = arith.mulf %[[T0]], %[[T2]] : vector<2xf32>
-// CHECK:    %[[T10:.*]] = vector.reduction <add>, %[[T9]] : vector<2xf32> into f32
-// CHECK:    %[[T11:.*]] = vector.insert %[[T10]], %[[R]] [0, 0] : f32 into vector<2x2xf32>
+// CHECK-LABEL: func @contract_to_dot_matmat
+// CHECK-SAME: %[[LHS:.*0]]: vector<2x2xf32>,
+// CHECK-SAME: %[[RHS:.*1]]: vector<2x2xf32>,
+// CHECK-SAME: %[[OUT:.*2]]: vector<2x2xf32>
 //
-// CHECK:    %[[T12:.*]] = vector.extract %[[Bt]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK:    %[[T19:.*]] = arith.mulf %[[T0]], %[[T12]] : vector<2xf32>
-// CHECK:    %[[T20:.*]] = vector.reduction <add>, %[[T19]] : vector<2xf32> into f32
-// CHECK:    %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [0, 1] : f32 into vector<2x2xf32>
+// The `vector.contract` to dot lowering will 'unroll' a matrix-matrix
+// multiplication into individual dot products between rows of the LHS and columns
+// of the RHS. In the following test we expect 4 extract-dotproduct-insert sequences of
+// ops that correspond to the 4 dot products resulting from unrolling a matmul between
+// two matrices of size (2, 2).
 //
-// CHECK:    %[[T23:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK:    %[[T24:.*]] = vector.extract %[[Bt]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK:    %[[T32:.*]] = arith.mulf %[[T23]], %[[T24]] : vector<2xf32>
-// CHECK:    %[[T33:.*]] = vector.reduction <add>, %[[T32]] : vector<2xf32> into f32
-// CHECK:    %[[T34:.*]] = vector.insert %[[T33]], %[[T21]] [1, 0] : f32 into vector<2x2xf32>
+// CHECK:    %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
 //
-// CHECK:    %[[T40:.*]] = vector.extract %[[Bt]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK:    %[[T41:.*]] = arith.mulf %[[T23]], %[[T40]] : vector<2xf32>
-// CHECK:    %[[T42:.*]] = vector.reduction <add>, %[[T41]] : vector<2xf32> into f32
-// CHECK:    %[[T43:.*]] = vector.insert %[[T42]], %[[T34]] [1, 1] : f32 into vector<2x2xf32>
+// First, the RHS will be transposed to make it easier to extract individual columns
+// using vector.extract.
 //
-// CHECK:    %[[T52:.*]] = arith.addf %[[T43]], %[[C]] : vector<2x2xf32>
-// CHECK:    return %[[T52]] : vector<2x2xf32>
+// CHECK:    %[[RHS_T:.*]] = vector.transpose %[[RHS]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+//
+// Next, we expect 4 sequences of extracting rows of the LHS and the transposed RHS,
+// performing a dot product, and then inserting the result into the output vector.
+//
+// CHECK:    %[[LHS0:.*]] = vector.extract %[[LHS]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK:    %[[RHS_T0:.*]] = vector.extract %[[RHS_T]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK:    %[[PROD0:.*]] = arith.mulf %[[LHS0]], %[[RHS_T0]] : vector<2xf32>
+// CHECK:    %[[SUM0:.*]] = vector.reduction <add>, %[[PROD0]] : vector<2xf32> into f32
+// CHECK:    %[[RES0:.*]] = vector.insert %[[SUM0]], %[[INIT]] [0, 0] : f32 into vector<2x2xf32>
+//
+// CHECK:    %[[RHS_T1:.*]] = vector.extract %[[RHS_T]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK:    %[[PROD1:.*]] = arith.mulf %[[LHS0]], %[[RHS_T1]] : vector<2xf32>
+// CHECK:    %[[SUM1:.*]] = vector.reduction <add>, %[[PROD1]] : vector<2xf32> into f32
+// CHECK:    %[[RES1:.*]] = vector.insert %[[SUM1]], %[[RES0]] [0, 1] : f32 into vector<2x2xf32>
+//
+// CHECK:    %[[LHS1:.*]] = vector.extract %[[LHS]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK:    %[[PROD2:.*]] = arith.mulf %[[LHS1]], %[[RHS_T0]] : vector<2xf32>
+// CHECK:    %[[SUM2:.*]] = vector.reduction <add>, %[[PROD2]] : vector<2xf32> into f32
+// CHECK:    %[[RES2:.*]] = vector.insert %[[SUM2]], %[[RES1]] [1, 0] : f32 into vector<2x2xf32>
+//
+// CHECK:    %[[PROD3:.*]] = arith.mulf %[[LHS1]], %[[RHS_T1]] : vector<2xf32>
+// CHECK:    %[[SUM3:.*]] = vector.reduction <add>, %[[PROD3]] : vector<2xf32> into f32
+// CHECK:    %[[RES3:.*]] = vector.insert %[[SUM3]], %[[RES2]] [1, 1] : f32 into vector<2x2xf32>
+//
+// CHECK:    %[[RES:.*]] = arith.addf %[[RES3]], %[[OUT]] : vector<2x2xf32>
+// CHECK:    return %[[RES]] : vector<2x2xf32>
 
-func.func @extract_contract4(%arg0: vector<2x2xf32>,
-                        %arg1: vector<2x2xf32>,
-                        %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
-  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+func.func @contract_to_dot_matmat(%lhs: vector<2x2xf32>,
+                        %rhs: vector<2x2xf32>,
+                        %init: vector<2x2xf32>) -> vector<2x2xf32> {
+  %res = vector.contract #matmat_trait %lhs, %rhs, %init
     : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-  return %0 : vector<2x2xf32>
+  return %res : vector<2x2xf32>
 }
 
 


        


More information about the Mlir-commits mailing list