[Mlir-commits] [mlir] 10063c5 - [mlir][ArmSME] Move vector.print -> ArmSME lowering to VectorToArmSME (#74063)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 4 01:42:16 PST 2023
Author: Benjamin Maxwell
Date: 2023-12-04T09:42:11Z
New Revision: 10063c5a29b9dd5041ea7e0ab9569ed74ef54d32
URL: https://github.com/llvm/llvm-project/commit/10063c5a29b9dd5041ea7e0ab9569ed74ef54d32
DIFF: https://github.com/llvm/llvm-project/commit/10063c5a29b9dd5041ea7e0ab9569ed74ef54d32.diff
LOG: [mlir][ArmSME] Move vector.print -> ArmSME lowering to VectorToArmSME (#74063)
This moves the SME tile vector.print lowering from
`-convert-arm-sme-to-scf` to `-convert-vector-to-arm-sme`. This seems
like a more logical place, as this is lowering a vector op to ArmSME,
and it also prevents vector.print from blocking tile allocation.
Added:
Modified:
mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 69c68663070b6..fece03040dbb8 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -447,75 +447,13 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
}
};
-/// Lowers `vector.print` of a tile into a loop over the rows of the tile,
-/// extracting them via `arm_sme.move_tile_slice_to_vector`, then printing with
-/// a 1D `vector.print`.
-///
-/// BEFORE:
-/// ```mlir
-/// vector.print %tile : vector<[4]x[4]xf32>
-/// ```
-/// AFTER:
-/// ```mlir
-/// %c0 = arith.constant 0 : index
-/// %c1 = arith.constant 1 : index
-/// %c4 = arith.constant 4 : index
-/// %vscale = vector.vscale
-/// %svl_s = arith.muli %c4, %vscale : index
-/// scf.for %i = %c0 to %svl_s step %c1 {
-/// %tile_slice = arm_sme.move_tile_slice_to_vector %tile[%i]
-/// : vector<[4]xf32> from vector<[4]x[4]xf32>
-/// vector.print %tile_slice : vector<[4]xf32>
-/// }
-/// ```
-struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
- using OpRewritePattern<vector::PrintOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::PrintOp printOp,
- PatternRewriter &rewriter) const override {
- if (!printOp.getSource())
- return failure();
-
- VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
- if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
- return failure();
-
- auto loc = printOp.getLoc();
-
- // Create a loop over the rows of the tile.
- auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
- auto minTileRows =
- rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
- {
- // Loop body.
- rewriter.setInsertionPointToStart(forOp.getBody());
- // Extract the current row from the tile.
- Value rowIndex = forOp.getInductionVar();
- // FIXME: Forward tile IDs.
- // For now, if you vector.print a SME tile you need to do
- // -allocate-arm-sme-tiles after -convert-arm-sme-to-scf.
- auto tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
- loc, printOp.getSource(), rowIndex);
- // Print the row with a 1D vector.print.
- rewriter.create<vector::PrintOp>(loc, tileSlice,
- printOp.getPunctuation());
- }
-
- rewriter.eraseOp(printOp);
- return success();
- }
-};
-
} // namespace
void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
- patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
- TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion,
- TileVectorPrintOpConversion>(patterns.getContext());
+ patterns
+ .add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
+ TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion>(
+ patterns.getContext());
}
namespace {
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 3016c7b0a8477..4b3fd26c6d59e 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -666,15 +666,75 @@ struct VectorInsertToArmSMELowering
}
};
+/// Lowers `vector.print` of a tile into a loop over the rows of the tile,
+/// extracting them via `arm_sme.move_tile_slice_to_vector`, then printing with
+/// a 1D `vector.print`.
+///
+/// BEFORE:
+/// ```mlir
+/// vector.print %tile : vector<[4]x[4]xf32>
+/// ```
+/// AFTER:
+/// ```mlir
+/// %c0 = arith.constant 0 : index
+/// %c1 = arith.constant 1 : index
+/// %c4 = arith.constant 4 : index
+/// %vscale = vector.vscale
+/// %svl_s = arith.muli %c4, %vscale : index
+/// scf.for %i = %c0 to %svl_s step %c1 {
+/// %tile_slice = arm_sme.move_tile_slice_to_vector %tile[%i]
+/// : vector<[4]xf32> from vector<[4]x[4]xf32>
+/// vector.print %tile_slice : vector<[4]xf32>
+/// }
+/// ```
+struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
+ using OpRewritePattern<vector::PrintOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::PrintOp printOp,
+ PatternRewriter &rewriter) const override {
+ if (!printOp.getSource())
+ return failure();
+
+ VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
+ if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
+ return failure();
+
+ auto loc = printOp.getLoc();
+
+ // Create a loop over the rows of the tile.
+ auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+ auto minTileRows =
+ rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
+ auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+ {
+ // Loop body.
+ rewriter.setInsertionPointToStart(forOp.getBody());
+ // Extract the current row from the tile.
+ Value rowIndex = forOp.getInductionVar();
+ auto tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
+ loc, printOp.getSource(), rowIndex);
+ // Print the row with a 1D vector.print.
+ rewriter.create<vector::PrintOp>(loc, tileSlice,
+ printOp.getPunctuation());
+ }
+
+ rewriter.eraseOp(printOp);
+ return success();
+ }
+};
+
} // namespace
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
- patterns.add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
- SplatOpToArmSMELowering, TransferReadToArmSMELowering,
- TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
- VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
- VectorOuterProductToArmSMELowering,
- VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(
- &ctx);
+ patterns
+ .add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
+ SplatOpToArmSMELowering, TransferReadToArmSMELowering,
+ TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
+ VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
+ VectorOuterProductToArmSMELowering, VectorExtractToArmSMELowering,
+ VectorInsertToArmSMELowering, VectorPrintToArmSMELowering>(&ctx);
}
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index fc28645a7acf7..efefc6c49e08f 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -160,25 +160,3 @@ func.func @arm_sme_tile_store_hor_with_mask(%tile : vector<[4]x[4]xi32>, %dest :
arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}
-
-//===----------------------------------------------------------------------===//
-// vector.print
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @arm_sme_tile_print(%tile: vector<[4]x[4]xf32>)
-{
- vector.print %tile : vector<[4]x[4]xf32>
- return
-}
-// CHECK-LABEL: func.func @arm_sme_tile_print(
-// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xf32>) {
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
-// CHECK-DAG: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
-// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
-// CHECK-NEXT: %[[TILE_SLICE:.*]] = arm_sme.move_tile_slice_to_vector %[[TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
-// CHECK-NEXT: vector.print %[[TILE_SLICE]] : vector<[4]xf32>
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 2491b2e2468cd..5bc147c60f3a6 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -736,3 +736,25 @@ func.func @vector_outerproduct_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64
%result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
"prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
}
+
+//===----------------------------------------------------------------------===//
+// vector.print
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @vector_print_tile(%tile: vector<[4]x[4]xf32>)
+{
+ vector.print %tile : vector<[4]x[4]xf32>
+ return
+}
+// CHECK-LABEL: func.func @vector_print_tile(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xf32>) {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+// CHECK-DAG: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK-NEXT: %[[TILE_SLICE:.*]] = arm_sme.move_tile_slice_to_vector %[[TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+// CHECK-NEXT: vector.print %[[TILE_SLICE]] : vector<[4]xf32>
More information about the Mlir-commits
mailing list