[Mlir-commits] [mlir] [mlir][vector] add support for linearizing vector.bitcast in VectorLinearize (PR #123110)

Chao Chen llvmlistbot at llvm.org
Thu Jan 16 11:34:51 PST 2025


https://github.com/chencha3 updated https://github.com/llvm/llvm-project/pull/123110

>From d067a5ab18604996b9290e98701e3d8ac04efe7b Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 15 Jan 2025 19:10:38 +0000
Subject: [PATCH 1/4] add linearize pattern for bitcast

---
 .../Vector/Transforms/VectorLinearize.cpp     | 37 +++++++++++++++++--
 mlir/test/Dialect/Vector/linearize.mlir       | 21 ++++++++++-
 2 files changed, 54 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 68535ae5a7a5c6..b450ea91fef651 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -72,13 +72,14 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
     auto resType =
         getTypeConverter()->convertType<VectorType>(constOp.getType());
 
+    if (!resType)
+      return rewriter.notifyMatchFailure(loc, "can't convert return type");
+
     if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
       return rewriter.notifyMatchFailure(
           loc,
           "Cannot linearize a constant scalable vector that's not a splat");
 
-    if (!resType)
-      return rewriter.notifyMatchFailure(loc, "can't convert return type");
     if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           loc, "Can't flatten since targetBitWidth <= OpSize");
@@ -459,6 +460,35 @@ struct LinearizeVectorInsert final
 private:
   unsigned targetVectorBitWidth;
 };
+
+struct LinearizeVectorBitCast final
+    : public OpConversionPattern<vector::BitCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorBitCast(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+  LogicalResult
+  matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = castOp.getLoc();
+    auto resType = getTypeConverter()->convertType(castOp.getType());
+    if (!resType)
+      return rewriter.notifyMatchFailure(loc, "can't convert return type.");
+
+    if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          loc, "Can't flatten since targetBitWidth <= OpSize");
+
+    rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType, adaptor.getSource());
+    return mlir::success();
+  }
+private:
+  unsigned targetVectorBitWidth;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
@@ -486,6 +516,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
         if ((isa<arith::ConstantOp>(op) ||
+             isa<vector::BitCastOp>(op) ||
              op->hasTrait<OpTrait::Vectorizable>())) {
           return (isLessThanTargetBitWidth(op, targetBitWidth)
                       ? typeConverter.isLegal(op)
@@ -494,7 +525,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
         return std::nullopt;
       });
 
-  patterns.add<LinearizeConstant, LinearizeVectorizable>(
+  patterns.add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
       typeConverter, patterns.getContext(), targetBitWidth);
 }
 
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 543e76b5b26e0c..0358c2637f72b2 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -179,7 +179,7 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
 
 // ALL-LABEL:   func.func @test_extract_strided_slice_1_scalable(
 // ALL-SAME:    %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
-func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {  
+func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
   // ALL-NOT: vector.shuffle
   // ALL-NOT: vector.shape_cast
   // ALL: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]xf32> to vector<2x[8]xf32>
@@ -318,3 +318,22 @@ func.func @test_vector_extract_scalar() {
   %0 = vector.extract %cst[0] : i32 from vector<4xi32>
   return
 }
+
+// -----
+
+// ALL-LABEL: test_vector_bitcast
+// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<4x1xf32>)
+func.func @test_vector_bitcast(%arg0: vector<4x1xf32>) -> vector<4x2xf16> {
+
+  // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4xf32>
+  // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<4xf32> to vector<8xf16>
+  // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<8xf16> to vector<4x2xf16>
+
+  // BW-128: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4xf32>
+  // BW-128: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<4xf32> to vector<8xf16>
+  // BW-128: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<8xf16> to vector<4x2xf16>
+
+  // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4x2xf16>
+  %1 = vector.bitcast %arg0 : vector<4x1xf32> to vector<4x2xf16>
+  return %1 : vector<4x2xf16>
+}

>From 5f358831c7f25bc385a3d00f8a340766adbf1170 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 15 Jan 2025 19:36:22 +0000
Subject: [PATCH 2/4] code format

---
 .../Dialect/Vector/Transforms/VectorLinearize.cpp    | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b450ea91fef651..a89d2872e2434a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -482,9 +482,11 @@ struct LinearizeVectorBitCast final
       return rewriter.notifyMatchFailure(
           loc, "Can't flatten since targetBitWidth <= OpSize");
 
-    rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType, adaptor.getSource());
+    rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
+                                                   adaptor.getSource());
     return mlir::success();
   }
+
 private:
   unsigned targetVectorBitWidth;
 };
@@ -515,8 +517,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   typeConverter.addTargetMaterialization(materializeCast);
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
-        if ((isa<arith::ConstantOp>(op) ||
-             isa<vector::BitCastOp>(op) ||
+        if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) ||
              op->hasTrait<OpTrait::Vectorizable>())) {
           return (isLessThanTargetBitWidth(op, targetBitWidth)
                       ? typeConverter.isLegal(op)
@@ -525,8 +526,9 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
         return std::nullopt;
       });
 
-  patterns.add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
-      typeConverter, patterns.getContext(), targetBitWidth);
+  patterns
+      .add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
+          typeConverter, patterns.getContext(), targetBitWidth);
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(

>From f0b3ae0d45ebaea3659fed486a8cc973673814f6 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 16 Jan 2025 15:18:09 +0000
Subject: [PATCH 3/4] update tests

---
 mlir/test/Dialect/Vector/linearize.mlir | 37 +++++++++++++++++--------
 1 file changed, 25 insertions(+), 12 deletions(-)

diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 0358c2637f72b2..bab5c6c15ed8f1 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -322,18 +322,31 @@ func.func @test_vector_extract_scalar() {
 // -----
 
 // ALL-LABEL: test_vector_bitcast
-// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<4x1xf32>)
-func.func @test_vector_bitcast(%arg0: vector<4x1xf32>) -> vector<4x2xf16> {
-
-  // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4xf32>
-  // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<4xf32> to vector<8xf16>
-  // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<8xf16> to vector<4x2xf16>
+// ALL-SAME: %[[ORIG_ARG:.*]]: vector<4x4xf32>
+func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> {
+  // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x4xf32> to vector<16xf32>
+  // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<16xf32> to vector<32xf16>
+  // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<32xf16> to vector<4x8xf16>
+
+  // BW-128: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x4xf32> to vector<4x8xf16>
+  // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x4xf32> to vector<4x8xf16>
+  %1 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16>
+  return %1 : vector<4x8xf16>
+}
 
-  // BW-128: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4xf32>
-  // BW-128: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<4xf32> to vector<8xf16>
-  // BW-128: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<8xf16> to vector<4x2xf16>
+// -----
 
-  // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4x2xf16>
-  %1 = vector.bitcast %arg0 : vector<4x1xf32> to vector<4x2xf16>
-  return %1 : vector<4x2xf16>
+// ALL-LABEL: test_vector_bitcast
+// ALL-SAME: %[[ORIG_ARG:.*]]: vector<4x2xf32>
+func.func @test_vector_bitcast(%arg0: vector<4x2xf32>) -> vector<4x4xf16> {
+  // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x2xf32> to vector<8xf32>
+  // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<8xf32> to vector<16xf16>
+  // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<16xf16> to vector<4x4xf16>
+  // BW-128: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x2xf32> to vector<8xf32>
+  // BW-128: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<8xf32> to vector<16xf16>
+  // BW-128: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<16xf16> to vector<4x4xf16>
+
+  // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x2xf32> to vector<4x4xf16>
+  %1 = vector.bitcast %arg0 : vector<4x2xf32> to vector<4x4xf16>
+  return %1 : vector<4x4xf16>
 }

>From 25a8f39e7841d28c459a390175ea50cebf737b74 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 16 Jan 2025 19:34:32 +0000
Subject: [PATCH 4/4] add test for scalable vector

---
 .../Vector/Transforms/VectorLinearize.cpp     |  8 +++++
 mlir/test/Dialect/Vector/linearize.mlir       | 33 +++++++++++++++++++
 2 files changed, 41 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index a89d2872e2434a..3ecd585c5a26d5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -461,6 +461,14 @@ struct LinearizeVectorInsert final
   unsigned targetVectorBitWidth;
 };
 
+/// This pattern converts a BitCastOp that operates on n-D (n > 1)
+/// vectors into a BitCastOp that operates on linearized (1-D) vectors.
+/// For example,
+///   vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16>
+/// is converted to:
+///   %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32>
+///   %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16>
+///   %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16>
 struct LinearizeVectorBitCast final
     : public OpConversionPattern<vector::BitCastOp> {
   using OpConversionPattern::OpConversionPattern;
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index bab5c6c15ed8f1..de757fb9e4c1a5 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -350,3 +350,36 @@ func.func @test_vector_bitcast(%arg0: vector<4x2xf32>) -> vector<4x4xf16> {
   %1 = vector.bitcast %arg0 : vector<4x2xf32> to vector<4x4xf16>
   return %1 : vector<4x4xf16>
 }
+
+// -----
+
+// ALL-LABEL: test_vector_bitcast
+// ALL-SAME: %[[ORIG_ARG:.*]]: vector<4x[2]xf32>
+func.func @test_vector_bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> {
+  // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<[8]xf32>
+  // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<[8]xf32> to vector<[16]xf16>
+  // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<[16]xf16> to vector<4x[4]xf16>
+  // BW-128: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<[8]xf32>
+  // BW-128: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<[8]xf32> to vector<[16]xf16>
+  // BW-128: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<[16]xf16> to vector<4x[4]xf16>
+
+  // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<4x[4]xf16>
+  %1 = vector.bitcast %arg0 : vector<4x[2]xf32> to vector<4x[4]xf16>
+  return %1 : vector<4x[4]xf16>
+}
+
+// -----
+// ALL-LABEL: test_vector_bitcast
+// ALL-SAME: %[[ORIG_ARG:.*]]: vector<[4]x2xf32>
+func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
+  // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[8]xf32>
+  // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<[8]xf32> to vector<[16]xf16>
+  // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<[16]xf16> to vector<[4]x4xf16>
+  // BW-128: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[8]xf32>
+  // BW-128: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<[8]xf32> to vector<[16]xf16>
+  // BW-128: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<[16]xf16> to vector<[4]x4xf16>
+
+  // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[4]x4xf16>
+  %1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
+  return %1 : vector<[4]x4xf16>
+}



More information about the Mlir-commits mailing list