[Mlir-commits] [mlir] 4ff8f11 - [MLIR][Vector] Extend elementwise pattern to support unrolling from higher rank to lower rank (#162515)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 15 06:51:27 PDT 2025


Author: Nishant Patel
Date: 2025-10-15T06:51:23-07:00
New Revision: 4ff8f118cc91870aed357be351230df63ef14dcf

URL: https://github.com/llvm/llvm-project/commit/4ff8f118cc91870aed357be351230df63ef14dcf
DIFF: https://github.com/llvm/llvm-project/commit/4ff8f118cc91870aed357be351230df63ef14dcf.diff

LOG: [MLIR][Vector] Extend elementwise pattern to support unrolling from higher rank to lower rank (#162515)

This PR extends the elementwise unrolling pattern to support unrolling
from a higher rank to a lower rank. The approach is to prepend leading unit
dims to the lower-rank targetShape so that it matches the rank of the
original vector (ExtractStridedSlice requires the ranks of its input and
output to match), extract the slice, reshape it to targetShape's rank, and
then perform the operation.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
    mlir/test/Dialect/Vector/vector-unroll-options.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 14639c5f1cdd3..fbae0989bed26 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -465,26 +465,33 @@ struct UnrollElementwisePattern : public RewritePattern {
     auto targetShape = getTargetShape(options, op);
     if (!targetShape)
       return failure();
+    int64_t targetShapeRank = targetShape->size();
     auto dstVecType = cast<VectorType>(op->getResult(0).getType());
     SmallVector<int64_t> originalSize =
         *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
-    // Bail-out if rank(source) != rank(target). The main limitation here is the
-    // fact that `ExtractStridedSlice` requires the rank for the input and
-    // output to match. If needed, we can relax this later.
-    if (originalSize.size() != targetShape->size())
-      return rewriter.notifyMatchFailure(
-          op, "expected input vector rank to match target shape rank");
+    int64_t originalShapeRank = originalSize.size();
+
     Location loc = op->getLoc();
+
+    // Handle rank mismatch by adding leading unit dimensions to targetShape
+    SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
+    int64_t rankDiff = originalShapeRank - targetShapeRank;
+    std::fill(adjustedTargetShape.begin(),
+              adjustedTargetShape.begin() + rankDiff, 1);
+    std::copy(targetShape->begin(), targetShape->end(),
+              adjustedTargetShape.begin() + rankDiff);
+
+    int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
     // Prepare the result vector.
     Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                              rewriter.getZeroAttr(dstVecType));
-    SmallVector<int64_t> strides(targetShape->size(), 1);
-    VectorType newVecType =
+    SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
+    VectorType unrolledVecType =
         VectorType::get(*targetShape, dstVecType.getElementType());
 
     // Create the unrolled computation.
     for (SmallVector<int64_t> offsets :
-         StaticTileOffsetRange(originalSize, *targetShape)) {
+         StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
       SmallVector<Value> extractOperands;
       for (OpOperand &operand : op->getOpOperands()) {
         auto vecType = dyn_cast<VectorType>(operand.get().getType());
@@ -492,14 +499,31 @@ struct UnrollElementwisePattern : public RewritePattern {
           extractOperands.push_back(operand.get());
           continue;
         }
-        extractOperands.push_back(
-            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
-                loc, operand.get(), offsets, *targetShape, strides));
+        Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+            loc, operand.get(), offsets, adjustedTargetShape, strides);
+
+        // Reshape to remove leading unit dims if needed
+        if (adjustedTargetShapeRank > targetShapeRank) {
+          extracted = rewriter.createOrFold<vector::ShapeCastOp>(
+              loc, VectorType::get(*targetShape, vecType.getElementType()),
+              extracted);
+        }
+        extractOperands.push_back(extracted);
       }
+
       Operation *newOp = cloneOpWithOperandsAndTypes(
-          rewriter, loc, op, extractOperands, newVecType);
+          rewriter, loc, op, extractOperands, unrolledVecType);
+
+      Value computeResult = newOp->getResult(0);
+
+      // Use strides sized to targetShape for proper insertion
+      SmallVector<int64_t> insertStrides =
+          (adjustedTargetShapeRank > targetShapeRank)
+              ? SmallVector<int64_t>(targetShapeRank, 1)
+              : strides;
+
       result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
-          loc, newOp->getResult(0), result, offsets, strides);
+          loc, computeResult, result, offsets, insertStrides);
     }
     rewriter.replaceOp(op, result);
     return success();

diff  --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 35db14e0f7f1d..e5a98b5c67f33 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -188,15 +188,38 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
 //   CHECK-LABEL: func @vector_fma
 // CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
 
-// TODO: We should be able to unroll this like the example above - this will require extending UnrollElementwisePattern.
-func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
+func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
   %0 = vector.fma %a, %a, %a : vector<3x2x2xf32>
   return %0 : vector<3x2x2xf32>
 }
-// CHECK-LABEL: func @negative_vector_fma_3d
-//   CHECK-NOT: vector.extract_strided_slice
-//       CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32>
-//       CHECK: return
+// CHECK-LABEL: func @vector_fma_3d
+//  CHECK-SAME:   (%[[SRC:.*]]: vector<3x2x2xf32>) -> vector<3x2x2xf32> {
+//       CHECK:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x2xf32>
+//       CHECK:   %[[E_LHS_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_LHS_0:.*]] = vector.shape_cast %[[E_LHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E_RHS_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_RHS_0:.*]] = vector.shape_cast %[[E_RHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E_OUT_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_OUT_0:.*]] = vector.shape_cast %[[E_OUT_0]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[FMA0:.*]] = vector.fma %[[S_LHS_0]], %[[S_RHS_0]], %[[S_OUT_0]] : vector<2x2xf32>
+//       CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[FMA0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
+//       CHECK:   %[[E_LHS_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_LHS_1:.*]] = vector.shape_cast %[[E_LHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E_RHS_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_RHS_1:.*]] = vector.shape_cast %[[E_RHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E_OUT_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_OUT_1:.*]] = vector.shape_cast %[[E_OUT_1]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[FMA1:.*]] = vector.fma %[[S_LHS_1]], %[[S_RHS_1]], %[[S_OUT_1]] : vector<2x2xf32>
+//       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[FMA1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
+//       CHECK:   %[[E_LHS_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_LHS_2:.*]] = vector.shape_cast %[[E_LHS_2]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E_RHS_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_RHS_2:.*]] = vector.shape_cast %[[E_RHS_2]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E_OUT_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_OUT_2:.*]] = vector.shape_cast %[[E_OUT_2]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[FMA2:.*]] = vector.fma %[[S_LHS_2]], %[[S_RHS_2]], %[[S_OUT_2]] : vector<2x2xf32>
+//       CHECK:   %[[I2:.*]] = vector.insert_strided_slice %[[FMA2]], %[[I1]] {offsets = [2, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
+//       CHECK:   return %[[I2]] : vector<3x2x2xf32>
 
 func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
   %0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
@@ -440,3 +463,36 @@ func.func @vector_step() -> vector<32xindex> {
 // CHECK: %[[ADD3:.*]] = arith.addi %[[STEP]], %[[CST]] : vector<8xindex>
 // CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
 // CHECK: return %[[INS3]] : vector<32xindex>
+
+
+func.func @elementwise_3D_to_2D(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
+  %0 = arith.addf %v1, %v2 : vector<2x2x2xf32>
+  return %0 : vector<2x2x2xf32>
+}
+// CHECK-LABEL: func @elementwise_3D_to_2D
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<2x2x2xf32>, %[[ARG1:.*]]: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
+//       CHECK:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x2xf32>
+//       CHECK:   %[[E_LHS_0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_LHS_0:.*]] = vector.shape_cast %[[E_LHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E_RHS_0:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_RHS_0:.*]] = vector.shape_cast %[[E_RHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[ADD0:.*]] = arith.addf %[[S_LHS_0]], %[[S_RHS_0]] : vector<2x2xf32>
+//       CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[ADD0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
+//       CHECK:   %[[E_LHS_1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_LHS_1:.*]] = vector.shape_cast %[[E_LHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E_RHS_1:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_RHS_1:.*]] = vector.shape_cast %[[E_RHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[ADD1:.*]] = arith.addf %[[S_LHS_1]], %[[S_RHS_1]] : vector<2x2xf32>
+//       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
+//       CHECK:   return %[[I1]] : vector<2x2x2xf32>
+
+
+func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf32>) -> vector<2x2x2x2xf32> {
+  %0 = arith.addf %v1, %v2 : vector<2x2x2x2xf32>
+  return %0 : vector<2x2x2x2xf32>
+}
+
+// CHECK-LABEL: func @elementwise_4D_to_2D
+// CHECK-COUNT-4:   arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
+// CHECK-NOT: arith.addf
+// CHECK: return


        


More information about the Mlir-commits mailing list