[Mlir-commits] [mlir] [mlir][vector] Add linearization pattern for vector.splat (PR #137651)

Nishant Patel llvmlistbot at llvm.org
Thu May 1 11:10:51 PDT 2025


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/137651

>From 97a6c57217bab6a815ffcd9bd12905af4d5fca1a Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 25 Apr 2025 17:32:50 +0000
Subject: [PATCH 1/3] Add linearization pattern for vector.splat

---
 .../Vector/Transforms/VectorLinearize.cpp     | 63 ++++++++++++++++---
 mlir/test/Dialect/Vector/linearize.mlir       | 17 +++++
 2 files changed, 70 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index a009aa03aaf64..45c7e37738898 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -26,6 +26,9 @@
 
 using namespace mlir;
 
+/// Default targetVectBitWidth for the patterns below (max unsigned value).
+constexpr unsigned defaultTargetVectorBitWidth =
+    std::numeric_limits<unsigned>::max();
 static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
   auto resultTypes = op->getResultTypes();
   for (auto resType : resultTypes) {
@@ -82,7 +85,7 @@ struct LinearizeConstantLike final
 
   LinearizeConstantLike(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpTraitConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -136,7 +139,7 @@ struct LinearizeVectorizable final
 public:
   LinearizeVectorizable(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpTraitConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -175,7 +178,7 @@ struct LinearizeVectorExtractStridedSlice final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorExtractStridedSlice(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -289,7 +292,7 @@ struct LinearizeVectorShuffle final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorShuffle(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -362,13 +365,17 @@ struct LinearizeVectorExtract final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorExtract(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    // Skip if result is not a vector type
+    if (!isa<VectorType>(extractOp.getType()))
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "scalar extract is not supported.");
     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
     if (!dstTy)
       return rewriter.notifyMatchFailure(extractOp,
@@ -425,7 +432,7 @@ struct LinearizeVectorInsert final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorInsert(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -506,7 +513,7 @@ struct LinearizeVectorBitCast final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorBitCast(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -531,12 +538,48 @@ struct LinearizeVectorBitCast final
   unsigned targetVectorBitWidth;
 };
 
+/// This pattern linearizes vector.splat so that it produces a 1-D vector.
+/// For example,
+///   vector.splat %value : vector<4x4xf32>
+/// is converted to:
+///   %out_1d = vector.splat %value : vector<16xf32>
+///   %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
+/// The pattern only emits the 1-D splat: the bit-width check is done by the
+/// conversion target's legality callback, and the shape_cast back to the
+/// original type is expected from the type converter's cast materialization.
+struct LinearizeVectorSplat final
+    : public OpConversionPattern<vector::SplatOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorSplat(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstTy = getTypeConverter()->convertType(splatOp.getType());
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
+    rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
+                                                 dstTy);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, unsigned targetBitWidth) {
 
+  typeConverter.addConversion([](Type type) -> Type { return type; });
   typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
     if (!isLinearizableVector(type))
       return type;
@@ -557,7 +600,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   typeConverter.addTargetMaterialization(materializeCast);
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
-        if ((isa<vector::BitCastOp>(op) ||
+        if ((isa<vector::BitCastOp, vector::SplatOp>(op) ||
              op->hasTrait<OpTrait::ConstantLike>() ||
              op->hasTrait<OpTrait::Vectorizable>())) {
           return (isLessThanTargetBitWidth(op, targetBitWidth)
@@ -568,8 +611,8 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
       });
 
   patterns.add<LinearizeConstantLike, LinearizeVectorizable,
-               LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
-                                       targetBitWidth);
+               LinearizeVectorBitCast, LinearizeVectorSplat>(
+      typeConverter, patterns.getContext(), targetBitWidth);
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9052c6440e6ac..89f01abb79a74 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -399,3 +399,20 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
   %1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
   return %1 : vector<[4]x4xf16>
 }
+
+// -----
+// ALL-LABEL: linearize_vector_splat
+// ALL-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
+func.func @linearize_vector_splat(%arg0: i32) -> vector<4x2xi32> {
+  // DEFAULT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
+  // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
+  // DEFAULT: return %[[CAST]] : vector<4x2xi32>
+  // BW-128: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
+  // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
+  // BW-128: return %[[CAST]] : vector<4x2xi32>
+  // With a 0-bit target width the splat is expected to stay 2-D.
+  // BW-0: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4x2xi32>
+  // BW-0: return %[[SPLAT]] : vector<4x2xi32>
+  %0 = vector.splat %arg0 : vector<4x2xi32>
+  return %0 : vector<4x2xi32>
+}

>From ff82c484ce1b2a7e3cc137c6c77b9253cd1b3f8a Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 1 May 2025 18:08:39 +0000
Subject: [PATCH 2/3] Add missing newline at end of linearize.mlir

---
 mlir/test/Dialect/Vector/linearize.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index e3af10be7fd61..20169c15eb2c1 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -446,4 +446,4 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
   // BW-0: return %[[SPLAT]] : vector<4x[2]xi32>
   %0 = vector.splat %arg0 : vector<4x[2]xi32>
   return %0 : vector<4x[2]xi32>
-}
\ No newline at end of file
+}

>From 9b21851cfdc503bde21b0c2e83f37d228e314b3b Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 1 May 2025 18:10:36 +0000
Subject: [PATCH 3/3] Remove unused targetVectorBitWidth from LinearizeVectorSplat

---
 mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index dbed6d5a4cd75..c2c9c206dc2b2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -446,9 +446,6 @@ struct LinearizeVectorSplat final
                                                  dstTy);
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 } // namespace



More information about the Mlir-commits mailing list