[Mlir-commits] [mlir] 2dd3f42 - [mlir][ArmSME] Lower vector.broadcast to ArmSME

Cullen Rhodes llvmlistbot at llvm.org
Tue Aug 29 02:43:33 PDT 2023


Author: Cullen Rhodes
Date: 2023-08-29T09:43:16Z
New Revision: 2dd3f42083573e8a123b2e1daf8dc4587e278177

URL: https://github.com/llvm/llvm-project/commit/2dd3f42083573e8a123b2e1daf8dc4587e278177
DIFF: https://github.com/llvm/llvm-project/commit/2dd3f42083573e8a123b2e1daf8dc4587e278177.diff

LOG: [mlir][ArmSME] Lower vector.broadcast to ArmSME

This adds support for lowering vector.broadcast ops to SME, if the
source is either a scalar, 0-d vector, or 1-d vector, and the result a
2-d scalable vector that aligns with SME tiles.

This follows on from D157005, which introduced a vector-to-tile-slice op
that moves a 1-d scalable vector to a slice of a 2-d scalable vector
(tile). The lowering of vector.broadcast is similar, so a couple of
helper functions are added to avoid duplication.

Lowering of vector.broadcast contributes towards a path from linalg.fill
to SME.

Depends on D157005

Reviewed By: awarzynski, dcaballe

Differential Revision: https://reviews.llvm.org/D158586

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
    mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index aedb0bbde4858a..0a1a087d9c8d6c 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -24,6 +24,38 @@ static bool isSplatZero(Type elemType, DenseElementsAttr val) {
   return false;
 }
 
+/// Builds an `scf.for` that visits every slice of a ZA tile whose elements
+/// have type \p eltType. The loop runs from 0 to
+/// getSMETileSliceMinNumElts(eltType) * vscale with step 1, and its induction
+/// variable is the tile slice index.
+///
+/// Side effect: on return the rewriter's insertion point is at the start of
+/// the loop body, so callers can emit per-slice ops directly. Callers that
+/// need the previous insertion point back should hold an
+/// OpBuilder::InsertionGuard.
+static scf::ForOp getLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
+                                        Type eltType) {
+  // Keep this creation order: it is the order in which the ops are emitted.
+  Value unitStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  Value minNumSlices = rewriter.create<arith::ConstantIndexOp>(
+      loc, arm_sme::getSMETileSliceMinNumElts(eltType));
+  Value vscaleVal =
+      rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+  Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+
+  // The runtime slice count scales with vscale.
+  Value upperBound =
+      rewriter.create<arith::MulIOp>(loc, minNumSlices, vscaleVal);
+
+  scf::ForOp sliceLoop =
+      rewriter.create<scf::ForOp>(loc, zeroIdx, upperBound, unitStep);
+  rewriter.setInsertionPointToStart(sliceLoop.getBody());
+  return sliceLoop;
+}
+
+/// Returns an SME tile as a value of vector type \p type.
+///
+/// Emits, at the rewriter's current insertion point:
+///   1. an `arm_sme.get_tile_id` op whose integer result is as wide as the
+///      element type of \p type, then
+///   2. an `arm_sme.cast_tile_to_vector` op casting that tile ID to \p type.
+/// The cast op is returned.
+///
+/// \p type must have an integer or float element type, since its bit width is
+/// taken with getIntOrFloatBitWidth().
+static arm_sme::CastTileToVector
+getSMETileAndCastToVector(PatternRewriter &rewriter, Location loc,
+                          VectorType type) {
+  unsigned tileElementWidth = type.getElementType().getIntOrFloatBitWidth();
+
+  // Create an `arm_sme.get_tile_id` op; the tile ID's integer width matches
+  // the tile element width.
+  auto tileId = rewriter.create<arm_sme::GetTileID>(
+      loc, rewriter.getIntegerType(tileElementWidth));
+
+  // Create `arm_sme.cast_tile_to_vector` to cast the tile ID to a vector type.
+  return rewriter.create<arm_sme::CastTileToVector>(loc, type, tileId);
+}
+
 namespace {
 
 /// Conversion pattern for vector.transfer_write.
@@ -122,29 +154,10 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
         tileSliceType, denseAttr.getSplatValue<Attribute>());
     auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
 
-    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
-
-    // Create 'arm_sme.get_tile' op.
-    auto tileId = rewriter.create<arm_sme::GetTileID>(
-        loc, rewriter.getIntegerType(tileElementWidth));
-
-    // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to
-    // use as input tile to 'arm_sme.move_vector_to_tile_slice' ops.
-    auto tile =
-        rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);
-
-    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
-        loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
-    auto vscale =
-        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
-    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    auto numTileSlices =
-        rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-    // Create a loop that broadcasts the constant to each ZA tile slice.
-    auto forOp =
-        rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
-    rewriter.setInsertionPointToStart(forOp.getBody());
+    arm_sme::CastTileToVector tile =
+        getSMETileAndCastToVector(rewriter, loc, tileType);
+
+    auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
     auto tileSliceIndex = forOp.getInductionVar();
 
     // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile slice.
@@ -159,10 +172,78 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
   }
 };
 
+/// Conversion pattern for vector.broadcast.
+///
+/// Example:
+///
+///   %broadcast_to_tile = vector.broadcast %src : i32 to vector<[4]x[4]xi32>
+///
+/// is converted to:
+///
+///   %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
+///   scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
+///     arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile,
+///       %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+///   }
+///
+/// Supports broadcasts to a valid SME tile type from:
+///   * a scalar or a 0-d vector, which is first broadcast to a 1-d tile slice;
+///   * a 1-d vector whose type already equals the tile slice type (e.g.
+///     vector<[4]xi32> for vector<[4]x[4]xi32>), which is used as-is;
+///   * a 1-d vector with a single fixed-size element (e.g. vector<1xi32>),
+///     whose unit dimension vector.broadcast stretches; it is first broadcast
+///     to a 1-d tile slice.
+/// Any other source (e.g. a 1-d vector of a different type) is rejected
+/// without modifying the IR.
+struct BroadcastOpToArmSMELowering
+    : public OpRewritePattern<vector::BroadcastOp> {
+  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const final {
+    auto tileType = broadcastOp.getResultVectorType();
+    if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
+      return failure();
+
+    // Restores the insertion point on return; getLoopOverTileSlices below
+    // moves it into the loop body.
+    OpBuilder::InsertionGuard g(rewriter);
+    auto loc = broadcastOp.getLoc();
+
+    auto srcType = broadcastOp.getSourceType();
+    auto srcVectorType = dyn_cast<VectorType>(srcType);
+    auto tileElementType = tileType.getElementType();
+
+    // Type of a single tile slice: the tile type without its leading dim,
+    // e.g. vector<[4]xi32> for a vector<[4]x[4]xi32> tile.
+    auto tileSliceType =
+        VectorType::get(tileType.getShape().drop_front(), tileElementType,
+                        /*scalableDims=*/{true});
+
+    // Obtain the 1-d value written to every tile slice. Ops are only created
+    // on paths that go on to succeed, so failure() leaves the IR untouched.
+    Value broadcastOp1D;
+    if (srcType.isIntOrFloat() ||
+        (srcVectorType && srcVectorType.getRank() == 0)) {
+      // Broadcast scalar or 0-d vector to 1-d vector.
+      broadcastOp1D = rewriter.create<vector::BroadcastOp>(
+          loc, tileSliceType, broadcastOp.getSource());
+    } else if (srcVectorType && srcVectorType == tileSliceType) {
+      // Value to broadcast is already a tile slice, nothing to do.
+      broadcastOp1D = broadcastOp.getSource();
+    } else if (srcVectorType && srcVectorType.getRank() == 1 &&
+               srcVectorType.getDimSize(0) == 1 &&
+               !srcVectorType.isScalable()) {
+      // A fixed-size unit dimension (e.g. vector<1xi32>) is stretched by the
+      // broadcast. Forwarding it as-is would write a vector that is not a
+      // tile slice, so broadcast it to the tile slice type first.
+      broadcastOp1D = rewriter.create<vector::BroadcastOp>(
+          loc, tileSliceType, broadcastOp.getSource());
+    } else {
+      // Any other source (higher rank, or a 1-d vector that does not match
+      // the tile slice type) is not supported.
+      return failure();
+    }
+
+    arm_sme::CastTileToVector tile =
+        getSMETileAndCastToVector(rewriter, loc, tileType);
+
+    // Create a loop over ZA tile slices; the insertion point is now inside
+    // the loop body.
+    auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
+    auto tileSliceIndex = forOp.getInductionVar();
+
+    // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value to each
+    // tile slice.
+    rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+        loc, tileType, broadcastOp1D, tile, tileSliceIndex);
+
+    rewriter.setInsertionPointAfter(forOp);
+
+    // Uses of the broadcast now refer to the tile written by the loop.
+    rewriter.replaceOp(broadcastOp, tile);
+
+    return success();
+  }
+};
+
 } // namespace
 
+/// Adds the Vector-to-ArmSME rewrite patterns defined in this file
+/// (TransferWrite, VectorLoad, VectorStore, ConstantOp and Broadcast
+/// lowerings) to \p patterns.
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
   patterns.add<TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
-               VectorStoreToArmSMELowering, ConstantOpToArmSMELowering>(&ctx);
+               VectorStoreToArmSMELowering, ConstantOpToArmSMELowering,
+               BroadcastOpToArmSMELowering>(&ctx);
 }

diff  --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 76e601da73738d..8b6bd8f52d1900 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -153,3 +153,54 @@ func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> te
   %0 = vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, tensor<?x?xi8>
   return %0 : tensor<?x?xi8>
 }
+
+// =============================================================================
+// vector.broadcast
+// =============================================================================
+
+// -----
+
+// Scalar source: broadcast once to a 1-d tile slice, then written to every
+// slice of the tile inside a loop of vscale * 4 iterations.
+// CHECK-LABEL:   func.func @broadcast_vec2d_from_i32(
+// CHECK-SAME:                                        %[[SRC:.*]]: i32) {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[SRC_1D:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
+// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK: %[[TILE:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK: %[[VSCALE:.*]] = vector.vscale
+// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+// CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK:   %{{.*}} = arm_sme.move_vector_to_tile_slice %[[SRC_1D]], %[[TILE]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
+// CHECK: "prevent.dce"(%[[TILE]]) : (vector<[4]x[4]xi32>) -> ()
+func.func @broadcast_vec2d_from_i32(%arg0: i32) {
+  %0 = vector.broadcast %arg0 : i32 to vector<[4]x[4]xi32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
+  return
+}
+
+// -----
+
+// 0-d vector source: first broadcast to a 1-d tile slice (vector<[4]xf32>),
+// which is then moved into each tile slice.
+// CHECK-LABEL:   func.func @broadcast_vec2d_from_vec0d(
+// CHECK-SAME:                                          %[[SRC:.*]]: vector<f32>) {
+// CHECK: %[[SRC_1D:.*]] = vector.broadcast %[[SRC]] : vector<f32> to vector<[4]xf32>
+// CHECK: scf.for
+// CHECK:   arm_sme.move_vector_to_tile_slice %[[SRC_1D]], {{.*}}
+func.func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) {
+  %0 = vector.broadcast %arg0 : vector<f32> to vector<[4]x[4]xf32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
+  return
+}
+
+// -----
+
+// 1-d source that already has the tile slice type: it is moved into each
+// tile slice directly, with no intermediate vector.broadcast emitted.
+// CHECK-LABEL:   func.func @broadcast_vec2d_from_vec1d(
+// CHECK-SAME:                                          %[[SRC:.*]]: vector<[8]xi16>) {
+// CHECK-NOT: vector.broadcast
+// CHECK: scf.for
+// CHECK:   arm_sme.move_vector_to_tile_slice %[[SRC]], {{.*}}
+func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
+  %0 = vector.broadcast %arg0 : vector<[8]xi16> to vector<[8]x[8]xi16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
+  return
+}


        


More information about the Mlir-commits mailing list