[Mlir-commits] [mlir] 3a492ab - [mlir][vector] Add linearization pattern for vector.splat (#137651)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 1 14:26:29 PDT 2025


Author: Nishant Patel
Date: 2025-05-01T14:26:26-07:00
New Revision: 3a492abf05521467bc882094272d03c3eb6251c4

URL: https://github.com/llvm/llvm-project/commit/3a492abf05521467bc882094272d03c3eb6251c4
DIFF: https://github.com/llvm/llvm-project/commit/3a492abf05521467bc882094272d03c3eb6251c4.diff

LOG: [mlir][vector] Add linearization pattern for vector.splat (#137651)

This PR is part 2 of 4 of the breakdown of PR #136193.
It adds a linearization pattern for vector.splat.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
    mlir/test/Dialect/Vector/linearize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b9fdede535112..b9cef003fa365 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -293,6 +293,10 @@ struct LinearizeVectorExtract final
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    // Skip if result is not a vector type
+    if (!isa<VectorType>(extractOp.getType()))
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "scalar extract is not supported.");
     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
     assert(dstTy && "expected 1-D vector type");
 
@@ -415,6 +419,32 @@ struct LinearizeVectorBitCast final
   }
 };
 
+/// This pattern rewrites a vector.splat so that it produces a linearized vector.
+/// For example,
+///   vector.splat %value : vector<4x4xf32>
+/// is converted to:
+///   %out_1d = vector.splat %value : vector<16xf32>
+///   %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
+struct LinearizeVectorSplat final
+    : public OpConversionPattern<vector::SplatOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
+                       PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstTy = getTypeConverter()->convertType(splatOp.getType());
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
+    rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
+                                                 dstTy);
+    return success();
+  }
+};
+
 } // namespace
 
 /// Return true if the operation `op` does not support scalable vectors and
@@ -501,7 +531,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
     const TypeConverter &typeConverter, const ConversionTarget &target,
     RewritePatternSet &patterns) {
   patterns.add<LinearizeConstantLike, LinearizeVectorizable,
-               LinearizeVectorBitCast>(typeConverter, patterns.getContext());
+               LinearizeVectorBitCast, LinearizeVectorSplat>(
+      typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(

diff  --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 06eaf58b225ae..20169c15eb2c1 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -413,3 +413,37 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
   %1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
   return %1 : vector<[4]x4xf16>
 }
+
+// -----
+// ALL-LABEL: linearize_vector_splat
+// ALL-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
+func.func @linearize_vector_splat(%arg0: i32) -> vector<4x2xi32> {
+  // DEFAULT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
+  // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
+  // DEFAULT: return %[[CAST]] : vector<4x2xi32>
+  // BW-128: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
+  // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
+  // BW-128: return %[[CAST]] : vector<4x2xi32>
+
+  // BW-0: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4x2xi32>
+  // BW-0: return %[[SPLAT]] : vector<4x2xi32>
+  %0 = vector.splat %arg0 : vector<4x2xi32>
+  return %0 : vector<4x2xi32>
+}
+
+// -----
+// ALL-LABEL: linearize_scalable_vector_splat
+// ALL-SAME: (%[[ARG:.*]]: i32) -> vector<4x[2]xi32>
+func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
+  // DEFAULT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<[8]xi32>
+  // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<[8]xi32> to vector<4x[2]xi32>
+  // DEFAULT: return %[[CAST]] : vector<4x[2]xi32>
+  // BW-128: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<[8]xi32>
+  // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<[8]xi32> to vector<4x[2]xi32>
+  // BW-128: return %[[CAST]] : vector<4x[2]xi32>
+
+  // BW-0: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4x[2]xi32>
+  // BW-0: return %[[SPLAT]] : vector<4x[2]xi32>
+  %0 = vector.splat %arg0 : vector<4x[2]xi32>
+  return %0 : vector<4x[2]xi32>
+}


        


More information about the Mlir-commits mailing list