[Mlir-commits] [mlir] [mlir][vector] Add more patterns to Vector Linearize transformation (PR #136193)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 17 13:28:01 PDT 2025


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff HEAD~1 HEAD --extensions cpp -- mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
``````````

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 6de5d0c5a..0c3cd18ea 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -282,22 +282,24 @@ private:
 /// source vector using ExtractStridedSliceOp and inserting them into the
 /// destination vector using InsertStridedSliceOp.
 /// Following,
-///   vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into vector<4x4xf32>
+///   vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into
+///   vector<4x4xf32>
 /// is converted to :
-///   %0 = vector.extract_strided_slice %s {offsets=[0], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
-///   %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]} : vector<4xf32> into vector<16xf32>
-///   %2 = vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
-///   %3 = vector.insert_strided_slice %2, %1 {offsets=[4], strides=[1]} : vector<4xf32> into vector<16xf32>
+///   %0 = vector.extract_strided_slice %s {offsets=[0], sizes=[4], strides=[1]}
+///   : vector<4xf32> from vector<8xf32> %1 = vector.insert_strided_slice %0, %d
+///   {offsets=[0], strides=[1]} : vector<4xf32> into vector<16xf32> %2 =
+///   vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]} :
+///   vector<4xf32> from vector<8xf32> %3 = vector.insert_strided_slice %2, %1
+///   {offsets=[4], strides=[1]} : vector<4xf32> into vector<16xf32>
 struct LinearizeVectorInsertStridedSlice final
     : public OpConversionPattern<vector::InsertStridedSliceOp> {
-  using OpConversionPattern<
-      vector::InsertStridedSliceOp>::OpConversionPattern;
-      LinearizeVectorInsertStridedSlice(
-    const TypeConverter &typeConverter, MLIRContext *context,
-    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-    PatternBenefit benefit = 1)
-    : OpConversionPattern(typeConverter, context, benefit),
-      targetVectorBitWidth(targetVectBitWidth) {}
+  using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;
+  LinearizeVectorInsertStridedSlice(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
 
   LogicalResult
   matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
@@ -345,8 +347,9 @@ struct LinearizeVectorInsertStridedSlice final
     rewriter.replaceOp(op, dstValue);
     return success();
   }
-  private:
-    unsigned targetVectorBitWidth;
+
+private:
+  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the ShuffleOp that works on nD (n > 1)
@@ -619,22 +622,22 @@ private:
 /// is converted to :
 ///   %result = arith.constant dense<0.0> : vector<4x4xf32>
 ///   %slice_0 = vector.load %base[%indices] : vector<4xf32>
-///   %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into vector<4x4xf32>
-///   %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
-///   %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into vector<4x4xf32>
+///   %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into
+///   vector<4x4xf32> %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
+///   %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into
+///   vector<4x4xf32>
 ///   ...
 /// This unrolls the 2D vector load into multiple 1D vector loads and inserts
 /// them into the result vector. The pattern currently supports only 2D vectors
-struct LinearizeVectorLoad final
-    : public OpConversionPattern<vector::LoadOp> {
+struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern<vector::LoadOp>::OpConversionPattern;
 
   LinearizeVectorLoad(
-    const TypeConverter &typeConverter, MLIRContext *context,
-    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-    PatternBenefit benefit = 1)
-    : OpConversionPattern(typeConverter, context, benefit),
-      targetVectorBitWidth(targetVectBitWidth) {}
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
 
   LogicalResult
   matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
@@ -648,35 +651,33 @@ struct LinearizeVectorLoad final
     }
     auto unrollCount = shape[0];
     auto vecSize = shape[1];
-    auto newVecType =
-        VectorType::get({vecSize}, vecType.getElementType());
+    auto newVecType = VectorType::get({vecSize}, vecType.getElementType());
 
     llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
     Value xBaseIndex = indices[0];
 
     // Construct the 2D vector.
-    Value resultVec = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getZeroAttr(vecType));
+    Value resultVec =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecType));
     // Emit unrolled loads for each 1D vector slice.
     for (auto i = 0; i < unrollCount; i++) {
       Value xIndex = xBaseIndex;
       if (i) {
         auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
-        xIndex =
-            rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+        xIndex = rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
       }
       indices[0] = xIndex;
-      auto vec = rewriter.create<vector::LoadOp>(
-          loc, newVecType, adaptor.getBase(), indices);
-      resultVec =
-          rewriter.create<vector::InsertOp>(loc, vec, resultVec, i);
+      auto vec = rewriter.create<vector::LoadOp>(loc, newVecType,
+                                                 adaptor.getBase(), indices);
+      resultVec = rewriter.create<vector::InsertOp>(loc, vec, resultVec, i);
     }
 
     rewriter.replaceOp(loadOp, resultVec);
     return success();
   }
-  private:
-    unsigned targetVectorBitWidth;
+
+private:
+  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the StoreOp to a series of StoreOp & ExtractOp
@@ -689,19 +690,19 @@ struct LinearizeVectorLoad final
 ///   %slice_1 = vector.extract %source[1] : vector<4xf32>
 ///   vector.store %slice_1, %base[%indices + 1] : vector<4xf32>
 ///   ...
-/// This unrolls the 2D vector store into multiple 1D vector stores by extracting
-/// slices from the source vector and storing them into the destination.
-/// The pattern currently supports only 2D vectors
+/// This unrolls the 2D vector store into multiple 1D vector stores by
+/// extracting slices from the source vector and storing them into the
+/// destination. The pattern currently supports only 2D vectors
 struct LinearizeVectorStore final
     : public OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern<vector::StoreOp>::OpConversionPattern;
 
   LinearizeVectorStore(
-    const TypeConverter &typeConverter, MLIRContext *context,
-    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-    PatternBenefit benefit = 1)
-    : OpConversionPattern(typeConverter, context, benefit),
-      targetVectorBitWidth(targetVectBitWidth) {}
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
 
   LogicalResult
   matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
@@ -718,26 +719,26 @@ struct LinearizeVectorStore final
     llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
     Value xBaseIndex = indices[0];
 
-    auto vec = rewriter.create<vector::ShapeCastOp>(
-        loc, vecType, adaptor.getValueToStore());
+    auto vec = rewriter.create<vector::ShapeCastOp>(loc, vecType,
+                                                    adaptor.getValueToStore());
 
     for (auto i = 0; i < unrollCount; i++) {
       auto vecSlice = rewriter.create<vector::ExtractOp>(loc, vec, i);
       Value xIndex = xBaseIndex;
       if (i) {
         auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
-        xIndex =
-            rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+        xIndex = rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
       }
       indices[0] = xIndex;
       rewriter.create<vector::StoreOp>(loc, vecSlice, adaptor.getBase(),
-                                             indices);
+                                       indices);
     }
     rewriter.eraseOp(storeOp);
     return success();
   }
-  private:
-    unsigned targetVectorBitWidth;
+
+private:
+  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the SplatOp to work on a linearized vector.
@@ -754,11 +755,11 @@ struct LinearizeVectorSplat final
   using OpConversionPattern::OpConversionPattern;
 
   LinearizeVectorSplat(
-    const TypeConverter &typeConverter, MLIRContext *context,
-    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-    PatternBenefit benefit = 1)
-    : OpConversionPattern(typeConverter, context, benefit),
-      targetVectorBitWidth(targetVectBitWidth) {}
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
 
   LogicalResult
   matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
@@ -766,12 +767,13 @@ struct LinearizeVectorSplat final
     auto dstTy = getTypeConverter()->convertType(splatOp.getType());
     if (!dstTy)
       return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
-    rewriter.replaceOpWithNewOp<vector::SplatOp>(
-        splatOp, adaptor.getInput(), dstTy);
+    rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
+                                                 dstTy);
     return success();
   }
-  private:
-    unsigned targetVectorBitWidth;
+
+private:
+  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the CreateMaskOp to work on a
@@ -789,11 +791,11 @@ struct LinearizeVectorCreateMask final
   using OpConversionPattern::OpConversionPattern;
 
   LinearizeVectorCreateMask(
-    const TypeConverter &typeConverter, MLIRContext *context,
-    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-    PatternBenefit benefit = 1)
-    : OpConversionPattern(typeConverter, context, benefit),
-      targetVectorBitWidth(targetVectBitWidth) {}
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
 
   LogicalResult
   matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
@@ -816,8 +818,9 @@ struct LinearizeVectorCreateMask final
         createMaskOp, dstTy, adaptor.getOperands().back());
     return success();
   }
-  private:
-    unsigned targetVectorBitWidth;
+
+private:
+  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts operations implementing the RegionBranchOpInterface
@@ -835,15 +838,14 @@ struct LinearizeRegionBranchOp final
       RegionBranchOpInterface>::OpInterfaceConversionPattern;
 
   LinearizeRegionBranchOp(
-    const TypeConverter &typeConverter, MLIRContext *context,
-    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-    PatternBenefit benefit = 1)
-    : OpInterfaceConversionPattern(typeConverter, context, benefit),
-      targetVectorBitWidth(targetVectBitWidth) {}
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpInterfaceConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
 
   LogicalResult
-  matchAndRewrite(RegionBranchOpInterface op,
-                  ArrayRef<Value> operands,
+  matchAndRewrite(RegionBranchOpInterface op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
     auto converter = getTypeConverter();
@@ -907,8 +909,9 @@ struct LinearizeRegionBranchOp final
     rewriter.finalizeOpModification(op);
     return success();
   }
-  private:
-    unsigned targetVectorBitWidth;
+
+private:
+  unsigned targetVectorBitWidth;
 };
 
 } // namespace
@@ -939,9 +942,9 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   target.addLegalOp<mlir::vector::ShapeCastOp>();
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
-        if ((isa<vector::BitCastOp, vector::LoadOp,
-                 vector::StoreOp, vector::CreateMaskOp,
-                 RegionBranchOpInterface, vector::SplatOp>(op) ||
+        if ((isa<vector::BitCastOp, vector::LoadOp, vector::StoreOp,
+                 vector::CreateMaskOp, RegionBranchOpInterface,
+                 vector::SplatOp>(op) ||
              op->hasTrait<OpTrait::ConstantLike>() ||
              op->hasTrait<OpTrait::Vectorizable>())) {
           return (isLessThanTargetBitWidth(op, targetBitWidth)
@@ -951,12 +954,11 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
         return std::nullopt;
       });
 
-  patterns.add<LinearizeConstantLike, LinearizeVectorizable,
-               LinearizeVectorBitCast, LinearizeVectorLoad,
-               LinearizeVectorStore, LinearizeVectorSplat,
-               LinearizeVectorCreateMask, LinearizeRegionBranchOp
-               >(typeConverter, patterns.getContext(),
-                                       targetBitWidth);
+  patterns
+      .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
+           LinearizeVectorLoad, LinearizeVectorStore, LinearizeVectorSplat,
+           LinearizeVectorCreateMask, LinearizeRegionBranchOp>(
+          typeConverter, patterns.getContext(), targetBitWidth);
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
@@ -972,16 +974,16 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
       });
 
   target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
-    [=](vector::InsertStridedSliceOp op) -> bool {
-      if(isLessThanTargetBitWidth(op, targetBitWidth)) {
-        auto srcTy = op.getSourceVectorType();
-        auto dstTy = op.getDestVectorType();
-        if (!op.hasNonUnitStrides() && srcTy.getRank() == 2 &&
-            srcTy.hasStaticShape() && dstTy.hasStaticShape())
-          return false;
-      }
-      return true;
-    });
+      [=](vector::InsertStridedSliceOp op) -> bool {
+        if (isLessThanTargetBitWidth(op, targetBitWidth)) {
+          auto srcTy = op.getSourceVectorType();
+          auto dstTy = op.getDestVectorType();
+          if (!op.hasNonUnitStrides() && srcTy.getRank() == 2 &&
+              srcTy.hasStaticShape() && dstTy.hasStaticShape())
+            return false;
+        }
+        return true;
+      });
 
   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
                LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,

``````````

</details>


https://github.com/llvm/llvm-project/pull/136193


More information about the Mlir-commits mailing list