[Mlir-commits] [mlir] 9e1b825 - [mlir][ArmSME] Add conversion from ArmSME to SCF to materialize loops
Cullen Rhodes
llvmlistbot at llvm.org
Tue Aug 1 01:20:25 PDT 2023
Author: Cullen Rhodes
Date: 2023-08-01T08:20:02Z
New Revision: 9e1b8253214572fea7b7f90f9f8dff28f75444bc
URL: https://github.com/llvm/llvm-project/commit/9e1b8253214572fea7b7f90f9f8dff28f75444bc
DIFF: https://github.com/llvm/llvm-project/commit/9e1b8253214572fea7b7f90f9f8dff28f75444bc.diff
LOG: [mlir][ArmSME] Add conversion from ArmSME to SCF to materialize loops
Currently a loop is materialized when lowering ArmSME loads and stores
to intrinsics. This patch introduces two new ops to the ArmSME dialect
that map 1-1 with intrinsics:
1. arm_sme.load_tile_slice - Loads a 1D tile slice from
memory into a 2D SME "virtual tile".
2. arm_sme.store_tile_slice - Stores a 1D tile slice from a 2D SME
"virtual tile" into memory.
It also adds a new conversion pass, '-convert-arm-sme-to-scf', that
materializes loops with these ops. The existing load/store lowering to
intrinsics is updated to use them.
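The integration tests below now lower through a pipeline along these lines
(a sketch; see the updated RUN lines for the exact flags):
  mlir-opt -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
    -convert-vector-to-llvm="enable-arm-sme" -allocate-arm-sme-tiles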
Depends on D156517
Discourse thread:
https://discourse.llvm.org/t/loop-materialization-in-armsme/72354
Reviewed By: awarzynski, dcaballe, WanderAway
Differential Revision: https://reviews.llvm.org/D156467
Added:
mlir/include/mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h
mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
mlir/lib/Conversion/ArmSMEToSCF/CMakeLists.txt
mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
Modified:
mlir/include/mlir/Conversion/Passes.h
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
mlir/lib/Conversion/CMakeLists.txt
mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
mlir/test/Dialect/ArmSME/roundtrip.mlir
mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h b/mlir/include/mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h
new file mode 100644
index 00000000000000..3a28ca11862afd
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h
@@ -0,0 +1,29 @@
+//===- ArmSMEToSCF.h - Convert ArmSME to SCF dialect ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARMSMETOSCF_ARMSMETOSCF_H_
+#define MLIR_CONVERSION_ARMSMETOSCF_ARMSMETOSCF_H_
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTARMSMETOSCF
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Collect a set of patterns to convert from the ArmSME dialect to SCF.
+void populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns);
+
+/// Create a pass to convert a subset of ArmSME ops to SCF.
+std::unique_ptr<Pass> createConvertArmSMEToSCFPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARMSMETOSCF_ARMSMETOSCF_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 71d5c4aa267ee1..014e976586af4c 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -14,6 +14,7 @@
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
+#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 04d65964478da6..9608d771a5dd52 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1108,6 +1108,22 @@ def ConvertVectorToArmSME : Pass<"convert-vector-to-arm-sme"> {
let dependentDialects = ["arm_sme::ArmSMEDialect"];
}
+//===----------------------------------------------------------------------===//
+// ArmSMEToSCF
+//===----------------------------------------------------------------------===//
+
+def ConvertArmSMEToSCF : Pass<"convert-arm-sme-to-scf"> {
+ let summary = "Lower the operations from the ArmSME dialect into the SCF "
+ "dialect";
+ let constructor = "mlir::createConvertArmSMEToSCFPass()";
+ let dependentDialects = [
+ "scf::SCFDialect",
+ "arith::ArithDialect",
+ "vector::VectorDialect",
+ "arm_sme::ArmSMEDialect"
+ ];
+}
+
//===----------------------------------------------------------------------===//
// VectorToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 01925d023b7902..11b96f20acdfae 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -16,6 +16,7 @@
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
//===----------------------------------------------------------------------===//
// ArmSME dialect definition
@@ -307,6 +308,102 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
"`:` type($base) `,` type($valueToStore)";
}
+def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
+ AllTypesMatch<["tile", "result"]>
+]> {
+ let summary = "Tile slice load and update operation";
+ let description = [{
+    Loads a 1D tile slice from memory into a 2D SME "virtual tile". The tile
+    slice is a 1D vector whose type is determined by the 2D scalable vector
+    type of the tile. The tile slice index describes where in the input tile
+    the tile slice is loaded to. The updated tile is returned as the result.
+
+    The slice of memory read is defined by a base and indices and must be
+    contiguous. The memref must be either rank 1 or rank 2 with dynamic
+    dimensions, since the operation is scalable, and the element type must be
+    a scalar that matches the element type of the result.
+
+ Example 1: Load a vector<[16]xi8> tile slice from memory into tile at given index.
+ ```mlir
+ %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+ ```
+
+ Example 2: Load a vector<[4]xf32> tile slice from memory into tile at given index.
+ ```mlir
+ %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+ ```
+
+ Example 3: Load a vector<[1]xi128> tile slice from memory into tile at given index.
+ ```mlir
+ %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+ ```
+ }];
+ let arguments = (ins
+ Arg<AnyMemRef, "the reference to load from">:$base,
+ SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index);
+ let results = (outs SMETile:$result);
+
+ let extraClassDeclaration = [{
+ MemRefType getMemRefType() {
+ return ::llvm::cast<MemRefType>(getBase().getType());
+ }
+ VectorType getVectorType() {
+ return ::llvm::cast<VectorType>(getResult().getType());
+ }
+ }];
+
+ let assemblyFormat = [{
+ $base `[` $indices `]` `,` $tile `,` $tile_slice_index
+ attr-dict `:` type($base) `,` type($result)
+ }];
+}
+
+def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
+ let summary = "Tile slice store operation";
+ let description = [{
+    Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile
+    slice is a 1D vector whose type is determined by the 2D scalable vector
+    type of the tile. The tile slice index describes where in the input tile
+    the tile slice is stored from.
+
+    The slice of memory written is defined by a base and indices and must be
+    contiguous. The memref must be either rank 1 or rank 2 with dynamic
+    dimensions, since the operation is scalable, and the element type must be
+    a scalar that matches the element type of the input tile.
+
+    Example 1: Store a vector<[16]xi8> tile slice from tile at given index to memory.
+    ```mlir
+    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+    ```
+
+    Example 2: Store a vector<[4]xf32> tile slice from tile at given index to memory.
+    ```mlir
+    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+    ```
+
+    Example 3: Store a vector<[1]xi128> tile slice from tile at given index to memory.
+    ```mlir
+    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+    ```
+ }];
+ let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
+ Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
+ Variadic<Index>:$indices);
+ let extraClassDeclaration = [{
+ MemRefType getMemRefType() {
+ return ::llvm::cast<MemRefType>(getBase().getType());
+ }
+ VectorType getVectorType() {
+ return ::llvm::cast<VectorType>(getTile().getType());
+ }
+ }];
+
+ let assemblyFormat = [{
+ $tile `,` $tile_slice_index `,` $base `[` $indices `]`
+ attr-dict `:` type($base) `,` type($tile)
+ }];
+}
+
//===----------------------------------------------------------------------===//
// ArmSME Intrinsic op definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
new file mode 100644
index 00000000000000..e143726cf234f2
--- /dev/null
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -0,0 +1,187 @@
+//===- ArmSMEToSCF.cpp - Convert ArmSME to SCF dialect ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of ArmSME operations to SCF.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTARMSMETOSCF
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Lower `arm_sme.tile_load` to a loop over the tile slices, loading each
+/// slice using `arm_sme.load_tile_slice`.
+///
+/// BEFORE:
+/// ```mlir
+/// %tile = arm_sme.tile_load %src[%c0, %c0] :
+/// memref<?x?xi32>, vector<[4]x[4]xi32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %tile_id = arm_sme.get_tile_id : i32
+/// %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
+/// %vscale = vector.vscale
+/// %c0 = arith.constant 0 : index
+/// %c1 = arith.constant 1 : index
+/// %min_svl_s = arith.constant 4 : index
+/// %svl_s = arith.muli %min_svl_s, %vscale : index
+/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
+/// %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
+/// %tile, %tile_slice_idx : memref<?x?xi32>, vector<[4]x[4]xi32>
+/// }
+/// ```
+struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
+ using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
+ PatternRewriter &rewriter) const override {
+ OpBuilder::InsertionGuard g(rewriter);
+ auto loc = tileLoadOp.getLoc();
+ auto tileType = tileLoadOp.getVectorType();
+ auto tileElementType = tileType.getElementType();
+ unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+
+    // Create 'arm_sme.get_tile_id' op.
+ auto tileId = rewriter.create<arm_sme::GetTileID>(
+ loc, rewriter.getIntegerType(tileElementWidth));
+
+    // Create `arm_sme.cast_tile_to_vector` to cast the tile ID to a vector
+    // type, for use as the input tile of 'arm_sme.load_tile_slice' ops.
+ auto tile =
+ rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);
+
+ // Create a loop that loads each ZA tile slice from memory.
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+ loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
+ auto vscale =
+ rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+ auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto numTileSlices =
+ rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+ auto forOp =
+ rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+
+ rewriter.setInsertionPointToStart(forOp.getBody());
+
+ auto tileSliceIndex = forOp.getInductionVar();
+ // TODO: use indices
+ // Create 'arm_sme.load_tile_slice' to load tile slice from
+ // memory into tile.
+ rewriter.create<arm_sme::LoadTileSliceOp>(
+ loc, tileType, tileLoadOp.getBase(), tile, tileSliceIndex,
+ tileSliceIndex);
+
+ rewriter.setInsertionPointAfter(forOp);
+
+ // Replace 'arm_sme.tile_load' with the tile.
+ rewriter.replaceOp(tileLoadOp, tile);
+
+ return success();
+ }
+};
+
+/// Lower `arm_sme.tile_store` to a loop over the tile slices, storing each
+/// slice using `arm_sme.store_tile_slice`.
+///
+/// BEFORE:
+/// ```mlir
+/// arm_sme.tile_store %tile, %dest[%c0, %c0]
+///     : memref<?x?xi32>, vector<[4]x[4]xi32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %vscale = vector.vscale
+/// %c0 = arith.constant 0 : index
+/// %c1 = arith.constant 1 : index
+/// %min_svl_s = arith.constant 4 : index
+/// %svl_s = arith.muli %min_svl_s, %vscale : index
+/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
+/// arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx]
+/// : memref<?x?xi32>, vector<[4]x[4]xi32>
+/// }
+/// ```
+struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
+ using OpRewritePattern<arm_sme::TileStoreOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
+ PatternRewriter &rewriter) const override {
+ OpBuilder::InsertionGuard g(rewriter);
+ auto loc = tileStoreOp.getLoc();
+ auto tileType = tileStoreOp.getVectorType();
+ auto tileElementType = tileType.getElementType();
+
+    // Create a loop that stores each ZA tile slice to memory.
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+ loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
+ auto vscale =
+ rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+ auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto numTileSlices =
+ rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+ auto forOp =
+ rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+
+ rewriter.setInsertionPointToStart(forOp.getBody());
+
+ auto tileSliceIndex = forOp.getInductionVar();
+ // TODO: use indices
+ rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
+ tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
+ tileStoreOp.getBase(), tileSliceIndex);
+
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
+ patterns.add<TileLoadOpConversion, TileStoreOpConversion>(
+ patterns.getContext());
+}
+
+namespace {
+
+struct ConvertArmSMEToSCFPass
+ : public impl::ConvertArmSMEToSCFBase<ConvertArmSMEToSCFPass> {
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ ConversionTarget target(getContext());
+ populateArmSMEToSCFConversionPatterns(patterns);
+ target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
+ arith::ArithDialect, scf::SCFDialect>();
+ target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ signalPassFailure();
+ }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createConvertArmSMEToSCFPass() {
+ return std::make_unique<ConvertArmSMEToSCFPass>();
+}
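For completeness, a minimal sketch of driving the new pass programmatically
rather than via mlir-opt (assumes the standard pass-manager APIs; the helper
name lowerArmSMEToSCF is hypothetical, not part of this patch):
```cpp
#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"

// Run only the ArmSME-to-SCF conversion on the given module.
static mlir::LogicalResult lowerArmSMEToSCF(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createConvertArmSMEToSCFPass());
  return pm.run(module);
}
```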
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/CMakeLists.txt b/mlir/lib/Conversion/ArmSMEToSCF/CMakeLists.txt
new file mode 100644
index 00000000000000..3bf4d7082afe42
--- /dev/null
+++ b/mlir/lib/Conversion/ArmSMEToSCF/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_conversion_library(MLIRArmSMEToSCF
+ ArmSMEToSCF.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmSMEToSCF
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRArmSMEDialect
+ MLIRArmSMEUtils
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 3ec4eb4382a91f..9fabeae0710383 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -4,6 +4,7 @@ add_subdirectory(ArithCommon)
add_subdirectory(ArithToLLVM)
add_subdirectory(ArithToSPIRV)
add_subdirectory(ArmNeon2dToIntr)
+add_subdirectory(ArmSMEToSCF)
add_subdirectory(AsyncToLLVM)
add_subdirectory(BufferizationToMemRef)
add_subdirectory(ComplexToLibm)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index b3a747a8fe8448..d61bde971647dc 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -138,68 +138,40 @@ Value getTileSlicePtrIndex(unsigned rank, Value tileSliceIndex,
llvm_unreachable("memref has unexpected rank!");
}
-/// Conversion pattern for `arm_sme.tile_load` to SME intrinsics.
-///
-/// Lower `arm_sme.tile_load` to a loop over the rows of ZA and load each row
-/// using `arm_sme.intr.ld1*.horiz`.
-///
-/// BEFORE:
-/// ```mlir
-/// %tile = arm_sme.tile_load %base[%c0, %c0] :
-/// memref<?x?xi32>, vector<[4]x[4]xi32>
-/// ```
-///
-/// AFTER:
-/// ```mlir
-/// %tile_id = arm_sme.get_tile_id : i32
-/// %vscale = vector.vscale
-/// %c0 = arith.constant 0 : index
-/// %c1 = arith.constant 1 : index
-/// %min_svl_s = arith.constant 4 : index
-/// %svl_s = arith.muli %min_svl_s, %vscale : index
-/// scf.for %tile_slice = %c0 to %svl_s step %c1 {
-/// // (...)
-/// "arm_sme.intr.ld1w.horiz"(%ptrue_s, %ptr, %tile_id, %tile_slice) :
-/// (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-/// }
-/// %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
-/// ```
-struct TileLoadToArmSMELowering
- : public ConvertOpToLLVMPattern<arm_sme::TileLoadOp> {
- using ConvertOpToLLVMPattern<arm_sme::TileLoadOp>::ConvertOpToLLVMPattern;
+/// Lower `arm_sme.load_tile_slice` to SME intrinsics.
+struct LoadTileSliceToArmSMELowering
+ : public ConvertOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
+ using ConvertOpToLLVMPattern<
+ arm_sme::LoadTileSliceOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
- arm_sme::TileLoadOp::Adaptor adaptor,
+ matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
+ arm_sme::LoadTileSliceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto loc = tileLoadOp.getLoc();
- auto tileType = tileLoadOp.getVectorType();
+ auto loc = loadTileSliceOp.getLoc();
+ auto tileType = loadTileSliceOp.getVectorType();
auto tileElementType = tileType.getElementType();
unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
- // Create 'arm_sme.get_tile_id' op.
- auto tile = rewriter.create<arm_sme::GetTileID>(
- loc, rewriter.getIntegerType(tileElementWidth));
+ // Create 'arm_sme.cast_vector_to_tile' to get a tile ID for the tile being
+ // loaded to.
+ auto tile = rewriter.create<arm_sme::CastVectorToTile>(
+ loc, rewriter.getIntegerType(tileElementWidth),
+ loadTileSliceOp.getTile());
- // Create a loop that loads each ZA tile slice from memory.
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
auto vscale =
rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
// This describes both the number of ZA tile slices and the number of
// elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
// ..., SVL_Q).
auto numTileSlices =
rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
- auto forOp =
- rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
- rewriter.setInsertionPointToStart(forOp.getBody());
// Create 'arm_sme.intr.ld1*.horiz' intrinsic to load ZA tile slice.
- auto memRefType = tileLoadOp.getMemRefType();
- auto tileSlice = forOp.getInductionVar();
+ auto memRefType = loadTileSliceOp.getMemRefType();
+ auto tileSlice = loadTileSliceOp.getTileSliceIndex();
// TODO: The 'indices' argument for the 'base' memref is currently ignored,
// 'tileSliceIndex' should be added to 'indices[0]'.
Value tileSliceIndex = getTileSlicePtrIndex(memRefType.getRank(), tileSlice,
@@ -241,52 +213,27 @@ struct TileLoadToArmSMELowering
break;
}
- rewriter.setInsertionPointAfter(forOp);
-
// The load intrinsics have no result, replace 'arm_sme.tile_load' with
// 'arm_sme.cast_tile_to_vector' to preserve dataflow.
- rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(tileLoadOp, tileType,
- tile);
+ rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(loadTileSliceOp,
+ tileType, tile);
return success();
}
};
-/// Conversion pattern for `arm_sme.tile_store` to SME intrinsics.
-///
-/// Lower `arm_sme.tile_store` to a loop over the rows of ZA and store each row
-/// using `arm_sme.intr.st1*.horiz`.
-///
-/// BEFORE:
-/// ```mlir
-/// arm_sme.tile_store %value, %base[%c0, %c0] : memref<?x?xi32>,
-/// vector<[4]x[4]xi32
-/// ```
-///
-/// AFTER:
-/// ```mlir
-/// %tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xi32> to i32
-/// %vscale = vector.vscale
-/// %c0 = arith.constant 0 : index
-/// %c1 = arith.constant 1 : index
-/// %min_svl_s = arith.constant 4 : index
-/// %svl_s = arith.muli %min_svl_s, %vscale : index
-/// scf.for %tile_slice = %c0 to %svl_s step %c1 {
-/// // (...)
-/// "arm_sme.intr.st1w.horiz"(%ptrue_s, %ptr, %tile_id, %tile_slice) :
-/// (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-/// }
-/// ```
-struct TileStoreToArmSMELowering
- : public ConvertOpToLLVMPattern<arm_sme::TileStoreOp> {
- using ConvertOpToLLVMPattern<arm_sme::TileStoreOp>::ConvertOpToLLVMPattern;
+/// Lower `arm_sme.store_tile_slice` to SME intrinsics.
+struct StoreTileSliceToArmSMELowering
+ : public ConvertOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
+ using ConvertOpToLLVMPattern<
+ arm_sme::StoreTileSliceOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
- arm_sme::TileStoreOp::Adaptor adaptor,
+ matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
+ arm_sme::StoreTileSliceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto loc = tileStoreOp.getLoc();
- auto tileType = tileStoreOp.getVectorType();
+ auto loc = storeTileSliceOp.getLoc();
+ auto tileType = storeTileSliceOp.getVectorType();
auto tileElementType = tileType.getElementType();
unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
@@ -294,27 +241,21 @@ struct TileStoreToArmSMELowering
// being stored.
auto tile = rewriter.create<arm_sme::CastVectorToTile>(
loc, rewriter.getIntegerType(tileElementWidth),
- tileStoreOp.getValueToStore());
+ storeTileSliceOp.getTile());
- // Create a loop that stores each ZA tile slice to memory.
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
auto vscale =
rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
// This describes both the number of ZA tile slices and the number of
// elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
// ..., SVL_Q).
auto numTileSlices =
rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
- auto forOp =
- rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
- rewriter.setInsertionPointToStart(forOp.getBody());
// Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
- auto memRefType = tileStoreOp.getMemRefType();
- auto tileSlice = forOp.getInductionVar();
+ auto memRefType = storeTileSliceOp.getMemRefType();
+ auto tileSlice = storeTileSliceOp.getTileSliceIndex();
// TODO: The 'indices' argument for the 'base' memref is currently ignored,
// 'tileSliceIndex' should be added to 'indices[0]'.
Value tileSliceIndex = getTileSlicePtrIndex(memRefType.getRank(), tileSlice,
@@ -340,19 +281,19 @@ struct TileStoreToArmSMELowering
llvm_unreachable("unexpected element type!");
case 8:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
- tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
break;
case 16:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
- tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
break;
case 32:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
- tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
break;
case 64:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
- tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
break;
}
@@ -403,6 +344,6 @@ void mlir::configureArmSMELegalizeForExportTarget(
void mlir::populateArmSMELegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
- patterns.add<ZeroOpConversion, TileLoadToArmSMELowering,
- TileStoreToArmSMELowering>(converter);
+ patterns.add<ZeroOpConversion, StoreTileSliceToArmSMELowering,
+ LoadTileSliceToArmSMELowering>(converter);
}
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
new file mode 100644
index 00000000000000..b64c79663038c0
--- /dev/null
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt %s -convert-arm-sme-to-scf -cse -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @arm_sme_tile_load(
+// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>) {
+// CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+// CHECK-NEXT: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[TILE_SLICE_INDEX]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_load(%src : memref<?x?xi32>) {
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_tile_store(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi32>) {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[DEST]]{{\[}}%[[TILE_SLICE_INDEX]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_store(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
+ %c0 = arith.constant 0 : index
+ arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
new file mode 100644
index 00000000000000..2c26c62ad42481
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -split-input-file | FileCheck %s
+
+// This test verifies that the temporary casts emitted to preserve data flow
+// when lowering to intrinsics are correct. Canonicalization will remove these
+// casts.
+
+// CHECK-LABEL: @arm_sme_zero
+// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
+// CHECK: arm_sme.intr.zero
+// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
+// CHECK: scf.for
+// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8> to i8
+// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
+// CHECK: "arm_sme.intr.st1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_zero(%dest : memref<?x?xi8>) {
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.zero : vector<[16]x[16]xi8>
+ arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_tile_load
+// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
+// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
+// CHECK: scf.for
+// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8> to i8
+// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
+// CHECK: "arm_sme.intr.ld1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK: }
+// CHECK: return %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
+func.func @arm_sme_tile_load(%dest : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return %tile : vector<[16]x[16]xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_tile_store(
+// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>,
+// CHECK: scf.for
+// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
+// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
+// CHECK: "arm_sme.intr.st1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_tile_store(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
+ %c0 = arith.constant 0 : index
+ arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index a53be0d47d2cd2..022ae272c4a35a 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -351,3 +351,165 @@ func.func @arm_sme_tile_store_f64(%tile : vector<[2]x[2]xf64>, %dest : memref<?x
arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
return
}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]x[8]xi16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]x[2]xi64>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]x[8]xf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index 8d52ebb417b80d..402da661cfee44 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file | FileCheck %s
// CHECK-LABEL: @transfer_write_2d_zero_i8(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
@@ -8,13 +8,13 @@
// CHECK-DAG: %[[MIN_SVL_B:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[C255:.*]] = arith.constant 255 : i32
// CHECK-DAG: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
-// CHECK-DAG: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> ()
+// CHECK-DAG: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> ()
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
// CHECK-DAG: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
-// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64
@@ -35,6 +35,7 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
// CHECK-DAG: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
+// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[MIN_SVL_B:.*]] = arith.constant 16 : index
@@ -43,7 +44,7 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
-// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64
@@ -52,7 +53,6 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
// CHECK-NEXT: "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
// CHECK-NEXT: }
-// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
// CHECK-NEXT: return %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
func.func @vector_load_i8(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
%c0 = arith.constant 0 : index
@@ -66,6 +66,7 @@ func.func @vector_load_i8(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
// CHECK-SAME: %[[ARG0:.*]]: memref<?xi8>)
// CHECK-DAG: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?xi8> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
+// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[MIN_SVL_B:.*]] = arith.constant 16 : index
@@ -74,8 +75,11 @@ func.func @vector_load_i8(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
+// CHECK-NEXT: %[[VSCALE_1:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK-NEXT: %[[VSCALE_IDX_1:.*]] = builtin.unrealized_conversion_cast %[[VSCALE_1]] : i64 to index
+// CHECK-NEXT: %[[SVL_B_1:.*]] = arith.muli %[[VSCALE_IDX_1]], %[[MIN_SVL_B]] : index
// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
-// CHECK-NEXT: %[[SVL_B_I64:.*]] = arith.index_castui %[[SVL_B]] : index to i64
+// CHECK-NEXT: %[[SVL_B_I64:.*]] = arith.index_castui %[[SVL_B_1]] : index to i64
// CHECK-NEXT: %[[TILE_SLICE_IDX:.*]] = arith.muli %[[TILE_SLICE_I64]], %[[SVL_B_I64]] : i64
// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[TILE_SLICE_IDX]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
@@ -83,7 +87,6 @@ func.func @vector_load_i8(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
// CHECK-NEXT: "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
// CHECK-NEXT: }
-// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
// CHECK-NEXT: return %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
func.func @vector_load_i8_from_rank_1_memref(%arg0 : memref<?xi8>) -> vector<[16]x[16]xi8> {
%c0 = arith.constant 0 : index
@@ -97,11 +100,11 @@ func.func @vector_load_i8_from_rank_1_memref(%arg0 : memref<?xi8>) -> vector<[16
// CHECK-LABEL: @vector_load_i16(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi16>)
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
+// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xi16>
// CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
// CHECK: %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
// CHECK: arm_sme.intr.ld1h.horiz
-// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xi16>
func.func @vector_load_i16(%arg0 : memref<?x?xi16>) -> vector<[8]x[8]xi16> {
%c0 = arith.constant 0 : index
%tile = vector.load %arg0[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
@@ -113,12 +116,12 @@ func.func @vector_load_i16(%arg0 : memref<?x?xi16>) -> vector<[8]x[8]xi16> {
// CHECK-LABEL: @vector_load_i32(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi32>)
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
// CHECK-DAG: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
// CHECK: %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
// CHECK-NOT: arith.extui %[[TILE_ID]]
// CHECK-NOT: arith.trunci %[[TILE_ID]]
// CHECK: arm_sme.intr.ld1w.horiz
-// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
func.func @vector_load_i32(%arg0 : memref<?x?xi32>) -> vector<[4]x[4]xi32> {
%c0 = arith.constant 0 : index
%tile = vector.load %arg0[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -130,11 +133,11 @@ func.func @vector_load_i32(%arg0 : memref<?x?xi32>) -> vector<[4]x[4]xi32> {
// CHECK-LABEL: @vector_load_i64(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi64>)
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i64
+// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xi64>
// CHECK-DAG: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
// CHECK: %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
// CHECK: %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i64 to i32
// CHECK: arm_sme.intr.ld1d.horiz
-// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xi64>
func.func @vector_load_i64(%arg0 : memref<?x?xi64>) -> vector<[2]x[2]xi64> {
%c0 = arith.constant 0 : index
%tile = vector.load %arg0[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
@@ -146,11 +149,11 @@ func.func @vector_load_i64(%arg0 : memref<?x?xi64>) -> vector<[2]x[2]xi64> {
// CHECK-LABEL: @vector_load_f16(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf16>)
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
+// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xf16>
// CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
// CHECK: %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
// CHECK: arm_sme.intr.ld1h.horiz
-// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xf16>
func.func @vector_load_f16(%arg0 : memref<?x?xf16>) -> vector<[8]x[8]xf16> {
%c0 = arith.constant 0 : index
%tile = vector.load %arg0[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
@@ -162,11 +165,11 @@ func.func @vector_load_f16(%arg0 : memref<?x?xf16>) -> vector<[8]x[8]xf16> {
// CHECK-LABEL: @vector_load_bf16(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xbf16>)
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
+// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xbf16>
// CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
// CHECK: %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
// CHECK: arm_sme.intr.ld1h.horiz
-// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xbf16>
func.func @vector_load_bf16(%arg0 : memref<?x?xbf16>) -> vector<[8]x[8]xbf16> {
%c0 = arith.constant 0 : index
%tile = vector.load %arg0[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
@@ -178,12 +181,12 @@ func.func @vector_load_bf16(%arg0 : memref<?x?xbf16>) -> vector<[8]x[8]xbf16> {
// CHECK-LABEL: @vector_load_f32(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>)
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xf32>
// CHECK-DAG: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
// CHECK: %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
// CHECK-NOT: arith.extui %[[TILE_ID]]
// CHECK-NOT: arith.trunci %[[TILE_ID]]
// CHECK: arm_sme.intr.ld1w.horiz
-// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xf32>
func.func @vector_load_f32(%arg0 : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
%c0 = arith.constant 0 : index
%tile = vector.load %arg0[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
@@ -195,11 +198,11 @@ func.func @vector_load_f32(%arg0 : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
// CHECK-LABEL: @vector_load_f64(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf64>)
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i64
+// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xf64>
// CHECK-DAG: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
// CHECK: %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
// CHECK: %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i64 to i32
// CHECK: arm_sme.intr.ld1d.horiz
-// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xf64>
func.func @vector_load_f64(%arg0 : memref<?x?xf64>) -> vector<[2]x[2]xf64> {
%c0 = arith.constant 0 : index
%tile = vector.load %arg0[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
@@ -212,7 +215,6 @@ func.func @vector_load_f64(%arg0 : memref<?x?xf64>) -> vector<[2]x[2]xf64> {
// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>,
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
// CHECK-DAG: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-DAG: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[MIN_SVL_B:.*]] = arith.constant 16 : index
@@ -221,7 +223,8 @@ func.func @vector_load_f64(%arg0 : memref<?x?xf64>) -> vector<[2]x[2]xf64> {
// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
-// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK-NEXT: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
+// CHECK: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64
@@ -242,9 +245,9 @@ func.func @vector_store_i8(%tile : vector<[16]x[16]xi8>, %arg0 : memref<?x?xi8>)
// CHECK-LABEL: @vector_store_i16(
// CHECK-SAME: %[[TILE:.*]]: vector<[8]x[8]xi16>,
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi16>)
-// CHECK-DAG: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xi16> to i16
-// CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
+// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
// CHECK: %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
+// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xi16> to i16
// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
// CHECK: arm_sme.intr.st1h.horiz
func.func @vector_store_i16(%tile : vector<[8]x[8]xi16>, %arg0 : memref<?x?xi16>) {
@@ -258,9 +261,9 @@ func.func @vector_store_i16(%tile : vector<[8]x[8]xi16>, %arg0 : memref<?x?xi16>
// CHECK-LABEL: @vector_store_i32(
// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi32>)
-// CHECK-DAG: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
-// CHECK-DAG: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
+// CHECK: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
// CHECK: %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
+// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
// CHECK-NOT: arith.extui %[[CAST_VECTOR_TO_TILE]]
// CHECK-NOT: arith.trunci %[[CAST_VECTOR_TO_TILE]]
// CHECK: arm_sme.intr.st1w.horiz
@@ -275,9 +278,9 @@ func.func @vector_store_i32(%tile : vector<[4]x[4]xi32>, %arg0 : memref<?x?xi32>
// CHECK-LABEL: @vector_store_i64(
// CHECK-SAME: %[[TILE:.*]]: vector<[2]x[2]xi64>,
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi64>)
-// CHECK-DAG: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xi64> to i64
-// CHECK-DAG: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
+// CHECK: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
// CHECK: %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
+// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xi64> to i64
// CHECK: %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i64 to i32
// CHECK: arm_sme.intr.st1d.horiz
func.func @vector_store_i64(%tile : vector<[2]x[2]xi64>, %arg0 : memref<?x?xi64>) {
@@ -291,9 +294,9 @@ func.func @vector_store_i64(%tile : vector<[2]x[2]xi64>, %arg0 : memref<?x?xi64>
// CHECK-LABEL: @vector_store_f16(
// CHECK-SAME: %[[TILE:.*]]: vector<[8]x[8]xf16>,
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf16>)
-// CHECK-DAG: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xf16> to i16
-// CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
+// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
// CHECK: %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
+// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xf16> to i16
// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
// CHECK: arm_sme.intr.st1h.horiz
func.func @vector_store_f16(%tile : vector<[8]x[8]xf16>, %arg0 : memref<?x?xf16>) {
@@ -307,9 +310,9 @@ func.func @vector_store_f16(%tile : vector<[8]x[8]xf16>, %arg0 : memref<?x?xf16>
// CHECK-LABEL: @vector_store_bf16(
// CHECK-SAME: %[[TILE:.*]]: vector<[8]x[8]xbf16>,
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xbf16>)
-// CHECK-DAG: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xbf16> to i16
-// CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
+// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
// CHECK: %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
+// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xbf16> to i16
// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
// CHECK: arm_sme.intr.st1h.horiz
func.func @vector_store_bf16(%tile : vector<[8]x[8]xbf16>, %arg0 : memref<?x?xbf16>) {
@@ -322,9 +325,9 @@ func.func @vector_store_bf16(%tile : vector<[8]x[8]xbf16>, %arg0 : memref<?x?xbf
// CHECK-LABEL: @vector_store_f32(
// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xf32>,
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>)
-// CHECK-DAG: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xf32> to i32
-// CHECK-DAG: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
+// CHECK: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
// CHECK: %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
+// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xf32> to i32
// CHECK-NOT: arith.extui %[[CAST_VECTOR_TO_TILE]]
// CHECK-NOT: arith.trunci %[[CAST_VECTOR_TO_TILE]]
// CHECK: arm_sme.intr.st1w.horiz
@@ -339,9 +342,9 @@ func.func @vector_store_f32(%tile : vector<[4]x[4]xf32>, %arg0 : memref<?x?xf32>
// CHECK-LABEL: @vector_store_f64(
// CHECK-SAME: %[[TILE:.*]]: vector<[2]x[2]xf64>,
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf64>)
-// CHECK-DAG: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xf64> to i64
-// CHECK-DAG: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
+// CHECK: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
// CHECK: %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
+// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xf64> to i64
// CHECK: %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i64 to i32
// CHECK: arm_sme.intr.st1d.horiz
func.func @vector_store_f64(%tile : vector<[2]x[2]xf64>, %arg0 : memref<?x?xf64>) {
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
index f0db7529fd062e..0da4b1cf319e6d 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \
-// RUN: -convert-vector-to-arm-sme -convert-vector-to-llvm="enable-arm-sme" \
+// RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
// RUN: -allocate-arm-sme-tiles -test-lower-to-llvm | \
// RUN: mlir-translate -mlir-to-llvmir | \
// RUN: %lli_aarch64_cmd --march=aarch64 --mattr="+sve,+sme" \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
index e2991d18a03a1c..082419ce05eba3 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \
-// RUN: -convert-vector-to-arm-sme -convert-vector-to-llvm="enable-arm-sme" \
+// RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN: -convert-vector-to-llvm="enable-arm-sme" \
// RUN: -allocate-arm-sme-tiles -test-lower-to-llvm | \
// RUN: mlir-translate -mlir-to-llvmir | \
// RUN: %lli_aarch64_cmd --march=aarch64 --mattr="+sve,+sme" \