[Mlir-commits] [mlir] [mlir][SME] Add vector.splat -> SME conversion (PR #67659)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Sep 28 05:46:16 PDT 2023


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/67659

>From 8fc8d4c5cb88689e1a0de0b48b31587b27d39f07 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 28 Sep 2023 11:19:05 +0000
Subject: [PATCH 1/3] [mlir][SME] Add vector.splat -> SME conversion

This conversion is identical to vector.broadcast when broadcasting a
scalar.
---
 .../VectorToArmSME/VectorToArmSME.cpp         | 60 ++++++++++++++++++-
 .../Dialect/ArmSME/vector-ops-to-sme.mlir     | 38 ++++++++++++
 2 files changed, 97 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 264539b85c0ee23..a83c0e9cdafa521 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -240,6 +240,63 @@ struct BroadcastOpToArmSMELowering
   }
 };
 
+/// Conversion pattern for vector.splat.
+///
+/// Example:
+///
+///   %splat_to_tile = vector.splat %src : vector<[4]x[4]xi32>
+///
+/// is converted to (with %tile the SME tile backing the result):
+///
+///   %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
+///   scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
+///     arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile,
+///       %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+///   }
+///
+/// This should, in practice, be identical to vector.broadcast when
+/// broadcasting a scalar.
+struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
+  using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::SplatOp splatOp,
+                                PatternRewriter &rewriter) const final {
+    auto tileType = splatOp.getResult().getType();
+    if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
+      return failure();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    auto loc = splatOp.getLoc();
+
+    auto srcType = splatOp.getOperand().getType();
+    auto tileElementType = tileType.getElementType();
+    (void)srcType; // Silence unused-variable warning in NDEBUG builds.
+    assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
+
+    // First, broadcast the scalar to a 1-d vector.
+    auto tileSliceType =
+        VectorType::get(tileType.getShape().drop_front(), tileElementType,
+                        /*scalableDims=*/{true});
+    Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
+        loc, tileSliceType, splatOp.getOperand());
+
+    arm_sme::CastTileToVector tile =
+        getSMETileAndCastToVector(rewriter, loc, tileType);
+
+    // Next, create a loop over ZA tile slices and "move" the generated 1-d
+    // vector to each slice.
+    auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
+    auto tileSliceIndex = forOp.getInductionVar();
+
+    rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+        loc, tileType, broadcastOp1D, tile, tileSliceIndex);
+
+    rewriter.replaceOp(splatOp, tile);
+
+    return success();
+  }
+};
+
 /// Conversion pattern for vector.transpose.
 ///
 /// Stores the input tile to memory and reloads vertically.
@@ -319,5 +376,6 @@ void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
   patterns.add<TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
                VectorStoreToArmSMELowering, ConstantOpToArmSMELowering,
-               BroadcastOpToArmSMELowering, TransposeOpToArmSMELowering>(&ctx);
+               BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
+               TransposeOpToArmSMELowering>(&ctx);
 }
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index a64753578a1c861..3c08b1deafad27d 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -220,6 +220,44 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
   return
 }
 
+//===----------------------------------------------------------------------===//
+// vector.splat
+//===----------------------------------------------------------------------===//
+
+// -----
+// i32 scalar -> 32-bit tile: broadcast to 1-D once, then fill every slice.
+// CHECK-LABEL:   func.func @splat_vec2d_from_i32(
+// CHECK-SAME:      %[[SRC:.*]]: i32) {
+// CHECK:   %[[BCST:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
+// CHECK:   %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK:   arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK:   %[[VSCALE:.*]] = vector.vscale
+// CHECK:   %[[UB:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
+// CHECK:   scf.for {{.*}} to %[[UB]] {{.*}} {
+// CHECK:     arm_sme.move_vector_to_tile_slice %[[BCST]], {{.*}} : vector<[4]xi32> into vector<[4]x[4]xi32>
+func.func @splat_vec2d_from_i32(%arg0: i32) {
+  %0 = vector.splat %arg0 : vector<[4]x[4]xi32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
+  return
+}
+
+// -----
+// f16 scalar splatted into a 16-bit vector<[8]x[8]xf16> tile.
+// CHECK-LABEL:   func.func @splat_vec2d_from_f16(
+// CHECK-SAME:      %[[SRC:.*]]: f16) {
+// CHECK:   %[[BCST:.*]] = vector.broadcast %[[SRC]] : f16 to vector<[8]xf16>
+// CHECK:   %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
+// CHECK:   arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xf16>
+// CHECK:   %[[VSCALE:.*]] = vector.vscale
+// CHECK:   %[[UB:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
+// CHECK:   scf.for {{.*}} to %[[UB]] {{.*}} {
+// CHECK:     arm_sme.move_vector_to_tile_slice %[[BCST]], {{.*}} : vector<[8]xf16> into vector<[8]x[8]xf16>
+func.func @splat_vec2d_from_f16(%arg0: f16) {
+  %0 = vector.splat %arg0 : vector<[8]x[8]xf16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // vector.transpose
 //===----------------------------------------------------------------------===//

>From 01a92dafefbdc8164313d51e90cd2e71c3eaf444 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 28 Sep 2023 12:13:59 +0000
Subject: [PATCH 2/3] fixup! [mlir][SME] Add vector.splat -> SME conversion

Incorporate suggestions from Ben and Cullen
---
 mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp | 10 +++-------
 mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir       | 10 +++-------
 2 files changed, 6 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index a83c0e9cdafa521..3ffd5acb3f1e89d 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -254,8 +254,7 @@ struct BroadcastOpToArmSMELowering
 ///       %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
 ///   }
 ///
-/// This should, in practice, be identical to vector.broadcast when
-/// broadcasting a scalar.
+/// This is identical to vector.broadcast of a scalar.
 struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
   using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
 
@@ -265,7 +264,6 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
     if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
       return failure();
 
-    OpBuilder::InsertionGuard g(rewriter);
     auto loc = splatOp.getLoc();
 
     auto srcType = splatOp.getOperand().getType();
@@ -274,11 +272,9 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
     assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
 
     // First, broadcast the scalar to a 1-d vector.
-    auto tileSliceType =
-        VectorType::get(tileType.getShape().drop_front(), tileElementType,
-                        /*scalableDims=*/{true});
+    VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
     Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
-        loc, tileSliceType, splatOp.getOperand());
+        loc, tileSliceType, splatOp.getInput());
 
     arm_sme::CastTileToVector tile =
         getSMETileAndCastToVector(rewriter, loc, tileType);
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 3c08b1deafad27d..3fcaf6cad335329 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -232,8 +232,8 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
 // CHECK:   %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
 // CHECK:   arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
 // CHECK:   %[[VSCALE:.*]] = vector.vscale
-// CHECK:   %[[UB:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
-// CHECK:   scf.for {{.*}} to %[[UB]] {{.*}} {
+// CHECK:   %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
+// CHECK:   scf.for {{.*}} to %[[NUM_TILE_SLICES]] {{.*}} {
 // CHECK:     arm_sme.move_vector_to_tile_slice %[[BCST]], {{.*}} : vector<[4]xi32> into vector<[4]x[4]xi32>
 func.func @splat_vec2d_from_i32(%arg0: i32) {
   %0 = vector.splat %arg0 : vector<[4]x[4]xi32>
@@ -246,11 +246,7 @@ func.func @splat_vec2d_from_i32(%arg0: i32) {
 // CHECK-LABEL:   func.func @splat_vec2d_from_f16(
 // CHECK-SAME:      %[[SRC:.*]]: f16) {
 // CHECK:   %[[BCST:.*]] = vector.broadcast %[[SRC]] : f16 to vector<[8]xf16>
-// CHECK:   %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
-// CHECK:   arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xf16>
-// CHECK:   %[[VSCALE:.*]] = vector.vscale
-// CHECK:   %[[UB:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
-// CHECK:   scf.for {{.*}} to %[[UB]] {{.*}} {
+// CHECK:   scf.for
 // CHECK:     arm_sme.move_vector_to_tile_slice %[[BCST]], {{.*}} : vector<[8]xf16> into vector<[8]x[8]xf16>
 func.func @splat_vec2d_from_f16(%arg0: f16) {
   %0 = vector.splat %arg0 : vector<[8]x[8]xf16>

>From f41973d3029531492bca781a0a45be21b18cd030 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 28 Sep 2023 12:45:34 +0000
Subject: [PATCH 3/3] fixup! fixup! [mlir][SME] Add vector.splat -> SME
 conversion

Add missing insertion guard
---
 mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 3ffd5acb3f1e89d..d0e3dfc6ff9a0b2 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -264,6 +264,7 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
     if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
       return failure();
 
+    OpBuilder::InsertionGuard g(rewriter);
     auto loc = splatOp.getLoc();
 
     auto srcType = splatOp.getOperand().getType();



More information about the Mlir-commits mailing list