[Mlir-commits] [mlir] [mlir][ArmSME] Add custom vector.print lowering for SME tiles (PR #66691)
Benjamin Maxwell
llvmlistbot at llvm.org
Mon Sep 25 05:31:05 PDT 2023
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/66691
>From da0a15716d9b788fee432f15a2526deb60d3d3ed Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 20 Sep 2023 14:11:32 +0000
Subject: [PATCH 1/3] [mlir][ArmSME] Add custom vector.print lowering for SME
tiles
This adds a custom lowering for SME that loops over each row of the
tile, extracting it via an SME MOVA, then printing with a normal 1D
vector.print.
This makes writing SME integration tests easier and less verbose.
---
.../include/mlir/Dialect/ArmSME/Utils/Utils.h | 6 ++
.../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 93 ++++++++++++++++++-
.../Transforms/LegalizeForLLVMExport.cpp | 17 ----
mlir/lib/Dialect/ArmSME/Utils/Utils.cpp | 14 +++
.../ArmSMEToSCF/arm-sme-to-scf.mlir | 22 +++++
.../CPU/ArmSME/test-outerproduct-f32.mlir | 46 ++-------
.../CPU/ArmSME/test-outerproduct-f64.mlir | 26 +-----
.../Dialect/Vector/CPU/ArmSME/tile_fill.mlir | 28 +-----
8 files changed, 150 insertions(+), 102 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 9e8ad48b3c2db94..0941592497beaae 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -34,6 +34,12 @@ bool isValidSMETileElementType(Type type);
/// otherwise.
bool isValidSMETileVectorType(VectorType vType);
+/// Converts the tile ID `tile` into an i32 that can be passed as the `tile`
+/// operand of the SME intrinsics.
+///
+/// `tile` must be the result of an `arm_sme::GetTileID` or
+/// `arm_sme::CastVectorToTile` op, i.e. an 8/16/32/64/128-bit scalar integer.
+/// IDs narrower than 32 bits are zero-extended (`arith.extui`), wider IDs are
+/// truncated (`arith.trunci`), and an i32 `tile` is returned unchanged. Any new
+/// op is created at `loc` through `rewriter`.
+Value castTileIDToI32(Value tile, Location loc, RewriterBase &rewriter);
+
} // namespace arm_sme
} // namespace mlir
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index b128165f75b9e81..e390ad09276e0f1 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -190,11 +190,94 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
}
};
+/// Lowers `vector.print` of an SME tile into a loop over the rows of the
+/// tile, extracting each row via a MOVA (`arm_sme.intr.read.horiz`, with an
+/// all-true predicate and a zero vector as the passthrough), then printing it
+/// with a 1-D `vector.print`.
+///
+/// Only prints whose operand type satisfies `isValidSMETileVectorType` are
+/// rewritten; prints without a source operand are left alone. Every row is
+/// printed with the punctuation of the original op.
+///
+/// BEFORE:
+/// ```mlir
+/// vector.print %tile : vector<[4]x[4]xf32>
+/// ```
+/// AFTER:
+/// ```mlir
+/// %c0 = arith.constant 0 : index
+/// %c1 = arith.constant 1 : index
+/// %c4 = arith.constant 4 : index
+/// %ptrue = arith.constant dense<true> : vector<[4]xi1>
+/// %tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xf32> to i32
+/// %vscale = vector.vscale
+/// %svl_s = arith.muli %c4, %vscale : index
+/// %cst = arith.constant dense<0.000000e+00> : vector<[4]xf32>
+/// scf.for %i = %c0 to %svl_s step %c1 {
+///   %slice_idx = arith.index_cast %i : index to i32
+///   %tile_slice = "arm_sme.intr.read.horiz"
+///       (%cst, %ptrue, %tile_id, %slice_idx)
+///       : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+///   vector.print %tile_slice : vector<[4]xf32>
+/// }
+/// ```
+struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
+  using OpRewritePattern<vector::PrintOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::PrintOp printOp,
+                                PatternRewriter &rewriter) const override {
+    // Nothing to lower when the print has no source operand.
+    if (!printOp.getSource())
+      return failure();
+
+    VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
+    if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
+      return failure();
+
+    auto loc = printOp.getLoc();
+
+    // Create an 'all true' predicate for each tile row.
+    auto predicateType =
+        VectorType::get(vectorType.getDimSize(1), rewriter.getI1Type(), true);
+    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(predicateType, true));
+
+    // Cast tile to i32 tile ID.
+    auto tileId =
+        rewriter.create<arm_sme::CastVectorToTile>(loc, printOp.getSource());
+    auto tileIdI32 = castTileIDToI32(tileId, loc, rewriter);
+
+    // Zero destination/fallback for tile slice extraction.
+    auto rowType = VectorType::get(vectorType.getDimSize(1),
+                                   vectorType.getElementType(), true);
+    auto zeroVector = rewriter.create<arith::ConstantOp>(
+        loc, rowType, rewriter.getZeroAttr(rowType));
+
+    // Create a loop over the rows of the tile. The runtime row count is the
+    // static (minimum) row count of the type scaled by vscale.
+    auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+    auto minTileRows =
+        rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+    {
+      // Restore the insertion point when this scope exits, so nothing created
+      // after the loop body is accidentally placed inside the loop.
+      OpBuilder::InsertionGuard guard(rewriter);
+      // Loop body.
+      rewriter.setInsertionPointToStart(forOp.getBody());
+      // Extract the current row from the tile.
+      auto rowIndex = forOp.getInductionVar();
+      auto rowIndexI32 = rewriter.create<arith::IndexCastOp>(
+          loc, rewriter.getI32Type(), rowIndex);
+      auto tileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
+          loc, rowType, zeroVector, allTruePredicate, tileIdI32, rowIndexI32);
+      // Print the row with a 1D vector.print, reusing the original punctuation.
+      rewriter.create<vector::PrintOp>(loc, tileSlice,
+                                       printOp.getPunctuation());
+    }
+
+    rewriter.eraseOp(printOp);
+    return success();
+  }
+};
+
} // namespace
void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
- patterns.add<TileLoadOpConversion, TileStoreOpConversion>(
- patterns.getContext());
+  // TileVectorPrintOpConversion rewrites `vector.print` of an SME tile into an
+  // scf.for loop that prints the tile one row at a time with 1-D prints.
+ patterns.add<TileLoadOpConversion, TileStoreOpConversion,
+ TileVectorPrintOpConversion>(patterns.getContext());
}
namespace {
@@ -208,6 +291,12 @@ struct ConvertArmSMEToSCFPass
target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
arith::ArithDialect, scf::SCFDialect>();
target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
+ target.addDynamicallyLegalOp<vector::PrintOp>([](vector::PrintOp op) {
+ if (!op.getSource())
+ return true;
+ VectorType vectorType = dyn_cast<VectorType>(op.getPrintType());
+ return !vectorType || !arm_sme::isValidSMETileVectorType(vectorType);
+ });
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 18147542e2bca73..945716eea57543d 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -49,23 +49,6 @@ struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
}
};
-/// Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or
-/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar
-/// integer, to an i32 that can be passed as the `tile` parameter to the SME
-/// intrinsics. Or returns `tile` if already i32.
-Value castTileIDToI32(Value tile, Location loc,
- ConversionPatternRewriter &rewriter) {
- assert((isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
- tile.getDefiningOp())) &&
- "expected ArmSME GetTileID or CastVectorToTile op!");
- unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth();
- if (tileElementWidth < 32)
- return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
- if (tileElementWidth > 32)
- return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
- return tile;
-}
-
/// Lower 'arm_sme.zero' to SME intrinsics.
///
/// BEFORE:
diff --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
index b8a47951cc7bbba..f17077ff8565d59 100644
--- a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
using namespace mlir;
@@ -42,3 +43,16 @@ bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {
return true;
}
+
+/// Converts the tile ID `tile` to i32 for use as the `tile` operand of the SME
+/// intrinsics. `tile` must be produced by an `arm_sme::GetTileID` or
+/// `arm_sme::CastVectorToTile` op. Widths below 32 bits are zero-extended,
+/// widths above 32 bits are truncated, and an i32 `tile` is returned
+/// unchanged; any new op is created at `loc` through `rewriter`.
+Value mlir::arm_sme::castTileIDToI32(Value tile, Location loc,
+                                     RewriterBase &rewriter) {
+  // `isa_and_nonnull` rather than `isa`: a `tile` without a defining op (e.g.
+  // a block argument) then fails this assertion with its own message instead
+  // of tripping the null-pointer assertion inside `isa`.
+  assert((isa_and_nonnull<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
+             tile.getDefiningOp())) &&
+         "expected ArmSME GetTileID or CastVectorToTile op!");
+  unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth();
+  if (tileElementWidth < 32)
+    return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
+  if (tileElementWidth > 32)
+    return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
+  return tile;
+}
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 95b51317cb0cf1b..b287a00171d9bd5 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -56,3 +56,25 @@ func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x
arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}
+
+// -----
+
+// Verifies that `vector.print` of a vector<[4]x[4]xf32> tile becomes an
+// scf.for over 4 * vscale rows, where each iteration reads one row with
+// "arm_sme.intr.read.horiz" (all-true predicate, zero passthrough) and prints
+// it with a 1-D `vector.print`.
+func.func @arm_sme_tile_print(%tile: vector<[4]x[4]xf32>)
+{
+ vector.print %tile : vector<[4]x[4]xf32>
+ return
+}
+// CHECK-LABEL: func.func @arm_sme_tile_print(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xf32>) {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+// CHECK-DAG: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
+// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xf32> to i32
+// CHECK-DAG: %[[ZERO_VECTOR:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
+// CHECK-DAG: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK-NEXT: %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_cast %[[TILE_SLICE_INDEX]] : index to i32
+// CHECK-NEXT: %[[TILE_SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VECTOR]], %[[PTRUE]], %[[TILE_ID]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+// CHECK-NEXT: vector.print %[[TILE_SLICE]] : vector<[4]xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
index 00f1f6fd3fa8e19..4265ca0f599281c 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -16,7 +16,7 @@
llvm.func @printCString(!llvm.ptr<i8>)
-func.func @printTileBegin() {
+func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
@@ -25,7 +25,7 @@ func.func @printTileBegin() {
return
}
-func.func @printTileEnd() {
+func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
@@ -41,20 +41,8 @@ func.func @test_outerproduct_no_accumulator_4x4xf32() {
%vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
%tile = vector.outerproduct %vector, %vector : vector<[4]xf32>, vector<[4]xf32>
- // Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
- %vscale = vector.vscale
- %min_elts_s = arith.constant 4 : index
- %svl_s = arith.muli %min_elts_s, %vscale : index
- %za_s_size = arith.muli %svl_s, %svl_s : index
-
- // Allocate memory.
- %mem = memref.alloca(%za_s_size) : memref<?xf32>
-
- // Store the tile to memory.
- vector.store %tile, %mem[%c0] : memref<?xf32>, vector<[4]x[4]xf32>
-
- // Reload and print. The smallest SVL is 128-bits so the tile will be at
- // least 4x4xf32.
+ // Print the tile. The smallest SVL is 128-bits so the tile will be at least
+ // 4x4xf32.
//
// WITHOUT-ACC: TILE BEGIN
// WITHOUT-ACC-NEXT: ( 0, 0, 0, 0
@@ -63,10 +51,7 @@ func.func @test_outerproduct_no_accumulator_4x4xf32() {
// WITHOUT-ACC-NEXT: ( 0, 3, 6, 9
// WITHOUT-ACC: TILE END
func.call @printTileBegin() : () -> ()
- scf.for %i = %c0 to %za_s_size step %svl_s {
- %tileslice = vector.load %mem[%i] : memref<?xf32>, vector<[4]xf32>
- vector.print %tileslice : vector<[4]xf32>
- }
+ vector.print %tile : vector<[4]x[4]xf32>
func.call @printTileEnd() : () -> ()
return
@@ -81,20 +66,8 @@ func.func @test_outerproduct_with_accumulator_4x4xf32() {
%vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
%tile = vector.outerproduct %vector, %vector, %acc : vector<[4]xf32>, vector<[4]xf32>
- // Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
- %vscale = vector.vscale
- %min_elts_s = arith.constant 4 : index
- %svl_s = arith.muli %min_elts_s, %vscale : index
- %za_s_size = arith.muli %svl_s, %svl_s : index
-
- // Allocate memory.
- %mem = memref.alloca(%za_s_size) : memref<?xf32>
-
- // Store the tile to memory.
- vector.store %tile, %mem[%c0] : memref<?xf32>, vector<[4]x[4]xf32>
-
- // Reload and print. The smallest SVL is 128-bits so the tile will be at
- // least 4x4xf32.
+ // Print the tile. The smallest SVL is 128-bits so the tile will be at least
+ // 4x4xf32.
//
// WITH-ACC: TILE BEGIN
// WITH-ACC-NEXT: ( 10, 10, 10, 10
@@ -103,10 +76,7 @@ func.func @test_outerproduct_with_accumulator_4x4xf32() {
// WITH-ACC-NEXT: ( 10, 13, 16, 19
// WITH-ACC: TILE END
func.call @printTileBegin() : () -> ()
- scf.for %i = %c0 to %za_s_size step %svl_s {
- %tileslice = vector.load %mem[%i] : memref<?xf32>, vector<[4]xf32>
- vector.print %tileslice : vector<[4]xf32>
- }
+ vector.print %tile : vector<[4]x[4]xf32>
func.call @printTileEnd() : () -> ()
return
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
index 2c2a06fa8db26e1..cb2c6b98a4eef3a 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -13,7 +13,7 @@
llvm.func @printCString(!llvm.ptr<i8>)
-func.func @printTileBegin() {
+func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
@@ -22,7 +22,7 @@ func.func @printTileBegin() {
return
}
-func.func @printTileEnd() {
+func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
@@ -32,7 +32,6 @@ func.func @printTileEnd() {
}
func.func @test_outerproduct_with_accumulator_2x2xf64() {
- %c0 = arith.constant 0 : index
%f1 = arith.constant 1.0 : f64
%f2 = arith.constant 2.0 : f64
%f10 = arith.constant 10.0 : f64
@@ -44,30 +43,15 @@ func.func @test_outerproduct_with_accumulator_2x2xf64() {
%tile = vector.outerproduct %a, %b, %c : vector<[2]xf64>, vector<[2]xf64>
- // Calculate the size of a 64-bit tile, e.g. ZA{n}.d.
- %vscale = vector.vscale
- %min_elts_d = arith.constant 2 : index
- %svl_d = arith.muli %min_elts_d, %vscale : index
- %za_d_size = arith.muli %svl_d, %svl_d : index
-
- // Allocate memory.
- %mem = memref.alloca(%za_d_size) : memref<?xf64>
-
- // Store the tile to memory.
- vector.store %tile, %mem[%c0] : memref<?xf64>, vector<[2]x[2]xf64>
-
- // Reload and print. The smallest SVL is 128-bits so the tile will be at
- // least 2x2xf64.
+ // Print the tile. The smallest SVL is 128-bits so the tile will be at least
+ // 2x2xf64.
//
// CHECK: TILE BEGIN
// CHECK-NEXT: ( 12, 12
// CHECK-NEXT: ( 12, 12
// CHECK: TILE END
func.call @printTileBegin() : () -> ()
- scf.for %i = %c0 to %za_d_size step %svl_d {
- %tileslice = vector.load %mem[%i] : memref<?xf64>, vector<[2]xf64>
- vector.print %tileslice : vector<[2]xf64>
- }
+ vector.print %tile : vector<[2]x[2]xf64>
func.call @printTileEnd() : () -> ()
return
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
index a407b13b541839f..fe6ded71c1613fa 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
@@ -13,7 +13,7 @@
llvm.func @printCString(!llvm.ptr<i8>)
-func.func @printTileBegin() {
+func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
@@ -22,7 +22,7 @@ func.func @printTileBegin() {
return
}
-func.func @printTileEnd() {
+func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
@@ -32,29 +32,12 @@ func.func @printTileEnd() {
}
func.func @entry() -> i32 {
- %c0 = arith.constant 0 : index
- %c1_index = arith.constant 1 : index
-
- %min_elts_s = arith.constant 4 : index
- %vscale = vector.vscale
-
- // "svl" refers to the Streaming Vector Length and "svl_s" the number of
- // 32-bit elements in a vector of SVL bits.
- %svl_s = arith.muli %min_elts_s, %vscale : index
-
- // Allocate memory.
- %tilesize = arith.muli %svl_s, %svl_s : index
- %mem = memref.alloca(%tilesize) : memref<?xi32>
-
// Fill a tile with '123'. This will get lowered to a 1-d vector splat of
// '123' and a loop that writes this vector to each tile slice in the ZA
// tile.
%tile = arith.constant dense<123> : vector<[4]x[4]xi32>
- // Store tile to memory so it can be dumped.
- vector.store %tile, %mem[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
-
- // Dump "mem". The smallest SVL is 128-bits so the tile will be at least
+ // Print the tile. The smallest SVL is 128-bits so the tile will be at least
// 4x4xi32.
//
// CHECK: TILE BEGIN
@@ -64,10 +47,7 @@ func.func @entry() -> i32 {
// CHECK-NEXT: ( 123, 123, 123, 123
// CHECK: TILE END
func.call @printTileBegin() : () -> ()
- scf.for %i = %c0 to %tilesize step %svl_s {
- %tileslice = vector.load %mem[%i] : memref<?xi32>, vector<[4]xi32>
- vector.print %tileslice : vector<[4]xi32>
- }
+ vector.print %tile : vector<[4]x[4]xi32>
func.call @printTileEnd() : () -> ()
%c0_i32 = arith.constant 0 : i32
>From bfede7a34e36ec28cb0842acfd1856d03b2abf84 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 25 Sep 2023 12:11:30 +0000
Subject: [PATCH 2/3] Review fixup
---
mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index e390ad09276e0f1..d084373439ab6bf 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -238,7 +238,7 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
// Cast tile to i32 tile ID.
auto tileId =
rewriter.create<arm_sme::CastVectorToTile>(loc, printOp.getSource());
- auto tileIdI32 = castTileIDToI32(tileId, loc, rewriter);
+ Value tileIdI32 = castTileIDToI32(tileId, loc, rewriter);
// Zero destination/fallback for tile slice extraction.
auto rowType = VectorType::get(vectorType.getDimSize(1),
@@ -258,7 +258,7 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
// Loop body.
rewriter.setInsertionPointToStart(forOp.getBody());
// Extract the current row from the tile.
- auto rowIndex = forOp.getInductionVar();
+ Value rowIndex = forOp.getInductionVar();
auto rowIndexI32 = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getI32Type(), rowIndex);
auto tileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
>From 387684bdfff26bb79d85bc4191003524db3045bf Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 25 Sep 2023 12:27:42 +0000
Subject: [PATCH 3/3] Update recent tests
---
.../Vector/CPU/ArmSME/test-load-vertical.mlir | 15 ++--------
.../Vector/CPU/ArmSME/test-transpose.mlir | 28 ++++++-------------
2 files changed, 11 insertions(+), 32 deletions(-)
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
index d57001daf855f3d..8c7d8c954d38475 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
@@ -13,7 +13,7 @@
llvm.func @printCString(!llvm.ptr<i8>)
-func.func @printTileBegin() {
+func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
@@ -22,7 +22,7 @@ func.func @printTileBegin() {
return
}
-func.func @printTileEnd() {
+func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
@@ -44,7 +44,6 @@ func.func @entry() {
// Allocate memory.
%mem1 = memref.alloca(%za_s_size) : memref<?xi32>
- %mem2 = memref.alloca(%za_s_size) : memref<?xi32>
// Fill each "row" of "mem1" with row number.
//
@@ -66,11 +65,6 @@ func.func @entry() {
// Load tile from "mem1" vertically.
%0 = arm_sme.tile_load %mem1[%c0, %c0], <vertical> : memref<?xi32>, vector<[4]x[4]xi32>
- // Store tile back to "mem2" to print.
- // TODO: Support vector.print for 2-D scalable vectors so don't have to spill
- // to memory and reload to print.
- vector.store %0, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
-
// 1. ORIGINAL HORIZONTAL LAYOUT
// Dump "mem1". The smallest SVL is 128-bits so the tile will be at least
// 4x4xi32.
@@ -99,10 +93,7 @@ func.func @entry() {
// CHECK-NEXT: ( 0, 1, 2, 3
// CHECK: TILE END
func.call @printTileBegin() : () -> ()
- scf.for %i = %c0 to %za_s_size step %svl_s {
- %tileslice = vector.load %mem2[%i] : memref<?xi32>, vector<[4]xi32>
- vector.print %tileslice : vector<[4]xi32>
- }
+ vector.print %0 : vector<[4]x[4]xi32>
func.call @printTileEnd() : () -> ()
return
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
index 4350abbd13eca75..4bb9258098d98fd 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
@@ -13,7 +13,7 @@
llvm.func @printCString(!llvm.ptr<i8>)
-func.func @printTileBegin() {
+func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
@@ -22,7 +22,7 @@ func.func @printTileBegin() {
return
}
-func.func @printTileEnd() {
+func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
@@ -44,7 +44,6 @@ func.func @entry() {
// Allocate memory.
%mem1 = memref.alloca(%za_s_size) : memref<?xi32>
- %mem2 = memref.alloca(%za_s_size) : memref<?xi32>
// Fill each "row" of "mem1" with row number.
//
@@ -69,13 +68,8 @@ func.func @entry() {
// Transpose tile.
%transposed_tile = vector.transpose %tile, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
- // Store tile back to "mem2" to print.
- // TODO: Replace this with vector.print when
- // https://github.com/llvm/llvm-project/pull/66691 lands.
- vector.store %transposed_tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
-
- // Dump "mem1". The smallest SVL is 128-bits so the tile will be at least
- // 4x4xi32.
+ // Dump the original tile. The smallest SVL is 128-bits so the tile will be at
+ // least 4x4xi32.
//
// CHECK: TILE BEGIN
// CHECK-NEXT: ( 0, 0, 0, 0
@@ -84,14 +78,11 @@ func.func @entry() {
// CHECK-NEXT: ( 3, 3, 3, 3
// CHECK: TILE END
func.call @printTileBegin() : () -> ()
- scf.for %i = %c0 to %za_s_size step %svl_s {
- %tileslice = vector.load %mem1[%i] : memref<?xi32>, vector<[4]xi32>
- vector.print %tileslice : vector<[4]xi32>
- }
+ vector.print %tile : vector<[4]x[4]xi32>
func.call @printTileEnd() : () -> ()
- // Dump "mem2". The smallest SVL is 128-bits so the tile will be at least
- // 4x4xi32.
+ // Dump the transposed tile. The smallest SVL is 128-bits so the tile will be
+ // at least 4x4xi32.
//
// CHECK: TILE BEGIN
// CHECK-NEXT: ( 0, 1, 2, 3
@@ -100,10 +91,7 @@ func.func @entry() {
// CHECK-NEXT: ( 0, 1, 2, 3
// CHECK: TILE END
func.call @printTileBegin() : () -> ()
- scf.for %i = %c0 to %za_s_size step %svl_s {
- %tileslice = vector.load %mem2[%i] : memref<?xi32>, vector<[4]xi32>
- vector.print %tileslice : vector<[4]xi32>
- }
+ vector.print %transposed_tile : vector<[4]x[4]xi32>
func.call @printTileEnd() : () -> ()
return
More information about the Mlir-commits
mailing list