[Mlir-commits] [mlir] bd5d361 - [mlir][vector] add support for linearizing vector.bitcast in VectorLinearize (#123110)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 27 12:41:36 PST 2025


Author: Chao Chen
Date: 2025-01-27T14:41:33-06:00
New Revision: bd5d361c059814435bab24189e79e01d94c7039d

URL: https://github.com/llvm/llvm-project/commit/bd5d361c059814435bab24189e79e01d94c7039d
DIFF: https://github.com/llvm/llvm-project/commit/bd5d361c059814435bab24189e79e01d94c7039d.diff

LOG: [mlir][vector] add support for linearizing vector.bitcast in VectorLinearize (#123110)

This PR adds support for converting a vector::BitCastOp that operates on N-D
(N > 1) vectors into the same op operating on linearized (1-D) vectors.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
    mlir/test/Dialect/Vector/linearize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 68535ae5a7a5c6..3ecd585c5a26d5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -72,13 +72,14 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
     auto resType =
         getTypeConverter()->convertType<VectorType>(constOp.getType());
 
+    if (!resType)
+      return rewriter.notifyMatchFailure(loc, "can't convert return type");
+
     if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
       return rewriter.notifyMatchFailure(
           loc,
           "Cannot linearize a constant scalable vector that's not a splat");
 
-    if (!resType)
-      return rewriter.notifyMatchFailure(loc, "can't convert return type");
     if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           loc, "Can't flatten since targetBitWidth <= OpSize");
@@ -459,6 +460,45 @@ struct LinearizeVectorInsert final
 private:
   unsigned targetVectorBitWidth;
 };
+
+/// Converts a vector.bitcast on n-D (n > 1) vectors into a vector.bitcast on
+/// linearized (1-D) vectors. The match fails if the result type cannot be
+/// converted or if isLessThanTargetBitWidth rejects the op. For example,
+///   vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16>
+/// is converted to:
+///   %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32>
+///   %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16>
+///   %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16>
+struct LinearizeVectorBitCast final
+    : public OpConversionPattern<vector::BitCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorBitCast(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+  LogicalResult
+  matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = castOp.getLoc();
+    auto resType = getTypeConverter()->convertType(castOp.getType());
+    if (!resType)
+      return rewriter.notifyMatchFailure(loc, "can't convert return type.");
+
+    if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          loc, "Can't flatten since targetBitWidth <= OpSize");
+
+    rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
+                                                   adaptor.getSource());
+    return mlir::success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
@@ -485,7 +525,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   typeConverter.addTargetMaterialization(materializeCast);
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
-        if ((isa<arith::ConstantOp>(op) ||
+        if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) ||
              op->hasTrait<OpTrait::Vectorizable>())) {
           return (isLessThanTargetBitWidth(op, targetBitWidth)
                       ? typeConverter.isLegal(op)
@@ -494,8 +534,9 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
         return std::nullopt;
       });
 
-  patterns.add<LinearizeConstant, LinearizeVectorizable>(
-      typeConverter, patterns.getContext(), targetBitWidth);
+  patterns
+      .add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
+          typeConverter, patterns.getContext(), targetBitWidth);
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(

diff  --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 543e76b5b26e0c..99b1bbab1eede9 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -179,7 +179,7 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
 
 // ALL-LABEL:   func.func @test_extract_strided_slice_1_scalable(
 // ALL-SAME:    %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
-func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {  
+func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
   // ALL-NOT: vector.shuffle
   // ALL-NOT: vector.shape_cast
   // ALL: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]xf32> to vector<2x[8]xf32>
@@ -318,3 +318,68 @@ func.func @test_vector_extract_scalar() {
   %0 = vector.extract %cst[0] : i32 from vector<4xi32>
   return
 }
+
+// -----
+
+// ALL-LABEL: test_vector_bitcast
+// ALL-SAME: %[[ARG_0:.*]]: vector<4x4xf32>
+func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> {
+  // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x4xf32> to vector<16xf32>
+  // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<16xf32> to vector<32xf16>
+  // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<32xf16> to vector<4x8xf16>
+
+  // BW-128: %[[UPCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x4xf32> to vector<4x8xf16>
+  // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x4xf32> to vector<4x8xf16>
+  %1 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16>
+  return %1 : vector<4x8xf16>
+}
+
+// -----
+
+// ALL-LABEL: test_vector_bitcast
+// ALL-SAME: %[[ARG_0:.*]]: vector<4x2xf32>
+func.func @test_vector_bitcast(%arg0: vector<4x2xf32>) -> vector<4x4xf16> {
+  // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x2xf32> to vector<8xf32>
+  // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<8xf32> to vector<16xf16>
+  // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<16xf16> to vector<4x4xf16>
+  // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x2xf32> to vector<8xf32>
+  // BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<8xf32> to vector<16xf16>
+  // BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<16xf16> to vector<4x4xf16>
+
+  // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x2xf32> to vector<4x4xf16>
+  %1 = vector.bitcast %arg0 : vector<4x2xf32> to vector<4x4xf16>
+  return %1 : vector<4x4xf16>
+}
+
+// -----
+
+// ALL-LABEL: test_vector_bitcast
+// ALL-SAME: %[[ARG_0:.*]]: vector<4x[2]xf32>
+func.func @test_vector_bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> {
+  // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x[2]xf32> to vector<[8]xf32>
+  // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16>
+  // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<4x[4]xf16>
+  // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x[2]xf32> to vector<[8]xf32>
+  // BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16>
+  // BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<4x[4]xf16>
+
+  // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x[2]xf32> to vector<4x[4]xf16>
+  %1 = vector.bitcast %arg0 : vector<4x[2]xf32> to vector<4x[4]xf16>
+  return %1 : vector<4x[4]xf16>
+}
+
+// -----
+// ALL-LABEL: test_vector_bitcast
+// ALL-SAME: %[[ARG_0:.*]]: vector<[4]x2xf32>
+func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
+  // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<[4]x2xf32> to vector<[8]xf32>
+  // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16>
+  // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<[4]x4xf16>
+  // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<[4]x2xf32> to vector<[8]xf32>
+  // BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16>
+  // BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<[4]x4xf16>
+
+  // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<[4]x2xf32> to vector<[4]x4xf16>
+  %1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
+  return %1 : vector<[4]x4xf16>
+}


        


More information about the Mlir-commits mailing list