[Mlir-commits] [mlir] 447bb5b - [mlir][ArmSME] Introduce new lowering layer (Vector -> ArmSME)
Andrzej Warzynski
llvmlistbot at llvm.org
Tue Jul 18 01:06:58 PDT 2023
Author: Andrzej Warzynski
Date: 2023-07-18T08:04:59Z
New Revision: 447bb5bee402eab94987ebbd8f29d696f946ba9e
URL: https://github.com/llvm/llvm-project/commit/447bb5bee402eab94987ebbd8f29d696f946ba9e
DIFF: https://github.com/llvm/llvm-project/commit/447bb5bee402eab94987ebbd8f29d696f946ba9e.diff
LOG: [mlir][ArmSME] Introduce new lowering layer (Vector -> ArmSME)
At the moment, the lowering from the Vector dialect to SME looks like
this:
* Vector --> SME LLVM IR intrinsics
This patch introduces a new lowering layer between the Vector dialect
and the Arm SME extension:
* Vector --> ArmSME dialect (custom Ops) --> SME LLVM IR intrinsics.
This is motivated by two considerations:
1. Storing `ZA` to memory (e.g. `vector.transfer_write`) requires an
`scf.for` loop over all rows of `ZA`. Similar logic will apply to
"load to ZA from memory". This is a rather complex transformation and
a custom Op seems justified.
2. As discussed in [1], we need to prevent the LLVM type converter from
having to convert types unsupported in LLVM, e.g.
`vector<[16]x[16]xi8>`. A dedicated abstraction layer with custom Ops
opens a path to some fine tuning (e.g. custom type converters) that
will allow us to avoid this.
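To make consideration 1 concrete, here is a sketch of the row loop that storing `ZA` requires (illustrative MLIR only, names invented; the actual lowering added in this patch emits the corresponding SME intrinsics inside the loop body):

```mlir
// Sketch: iterate over all rows of ZA and store each row to memory.
// For i8 elements the number of rows is 16 * vscale.
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%vscale = vector.vscale
%num_rows = arith.muli %c16, %vscale : index
scf.for %row = %c0 to %num_rows step %c1 {
  // (store row %row of ZA to %memref[%row, 0])
}
```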
To facilitate this change, two new custom SME Ops are introduced:
* `TileStoreOp`, and
* `ZeroOp`.
Note that no new functionality is added - these Ops merely model what's
already supported. In particular, the following tile size is assumed
(dimension and element size are fixed):
* `vector<[16]x[16]xi8>`
The new lowering layer is introduced via a conversion pass between the
Vector and the ArmSME dialects. You can use the
`-convert-vector-to-arm-sme` flag to run it. The following function:
```
func.func @example(%arg0 : memref<?x?xi8>) {
// (...)
%cst = arith.constant dense<0> : vector<[16]x[16]xi8>
vector.transfer_write %cst, %arg0 : vector<[16]x[16]xi8>, memref<?x?xi8>
return
}
```
would be lowered to:
```
func.func @example(%arg0: memref<?x?xi8>) {
// (...)
%0 = arm_sme.zero : vector<[16]x[16]xi8>
arm_sme.tile_store %0, %arg0[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
return
}
```
Later, a mechanism will be introduced to guarantee that `arm_sme.zero`
and `arm_sme.tile_store` operate on the same virtual tile. For `i8`
elements this is not required as there is only one tile.
In order to lower the above output to LLVM, use
* `-convert-vector-to-llvm="enable-arm-sme"`.
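As exercised by the new `vector-ops-to-llvm.mlir` test, the two passes can be chained in a single `mlir-opt` invocation (the input file name here is illustrative):

```
mlir-opt example.mlir \
  -convert-vector-to-arm-sme \
  -convert-vector-to-llvm="enable-arm-sme"
```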
[1] https://github.com/openxla/iree/issues/14294
Reviewed By: WanderAway
Differential Revision: https://reviews.llvm.org/D154867
Added:
mlir/include/mlir/Conversion/VectorToArmSME/VectorToArmSME.h
mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp
mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
Modified:
mlir/include/mlir/Conversion/Passes.h
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
mlir/lib/Conversion/CMakeLists.txt
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
mlir/test/Dialect/ArmSME/roundtrip.mlir
mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
Removed:
mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp
mlir/test/Dialect/ArmSME/vector-ops.mlir
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index b15a60cfd005fb..21bc00c772875e 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -57,6 +57,7 @@
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
+#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index fd648d838d29b3..767843b73098ab 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1076,6 +1076,20 @@ def ConvertVectorToSCF : Pass<"convert-vector-to-scf"> {
];
}
+//===----------------------------------------------------------------------===//
+// VectorToArmSME
+//===----------------------------------------------------------------------===//
+
+def ConvertVectorToArmSME : Pass<"convert-vector-to-arm-sme"> {
+ let summary = "Lower the operations from the vector dialect into the ArmSME "
+ "dialect";
+ let description = [{
+ Pass that converts vector dialect operations into equivalent ArmSME dialect
+ operations.
+ }];
+ let dependentDialects = ["arm_sme::ArmSMEDialect"];
+}
+
//===----------------------------------------------------------------------===//
// VectorToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/VectorToArmSME/VectorToArmSME.h b/mlir/include/mlir/Conversion/VectorToArmSME/VectorToArmSME.h
new file mode 100644
index 00000000000000..2108e485dae7ba
--- /dev/null
+++ b/mlir/include/mlir/Conversion/VectorToArmSME/VectorToArmSME.h
@@ -0,0 +1,26 @@
+//===- VectorToArmSME.h - Convert vector to ArmSME dialect ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_VECTORTOARMSME_VECTORTOARMSME_H_
+#define MLIR_CONVERSION_VECTORTOARMSME_VECTORTOARMSME_H_
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTVECTORTOARMSME
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Collect a set of patterns to lower Vector ops to ArmSME ops that map to LLVM
+/// intrinsics.
+void populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
+ MLIRContext &ctx);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOARMSME_VECTORTOARMSME_H_
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index dacf23ceca2de0..d1ed02abfd5c55 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -15,6 +15,7 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 19283155b21714..09f8bfb314a6e9 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -33,7 +33,7 @@ def ArmSME_Dialect : Dialect {
https://developer.arm.com/documentation/ddi0616
https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
}];
- let dependentDialects = ["scf::SCFDialect"];
+ let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect"];
}
//===----------------------------------------------------------------------===//
@@ -196,6 +196,64 @@ def GetTileID : ArmSME_Op<"get_tile_id", [Pure]> {
let assemblyFormat = "attr-dict `:` type($tile_id)";
}
+//
+// Tile reset.
+//
+
+def ZeroOp : ArmSME_Op<"zero", [Pure]> {
+ let summary = "Initialize the two-dimensional ZA array with 0s";
+ let results = (outs nxnxv16i8:$res);
+ let description = [{
+ Initialise ZA with 0. This operation is a convenient wrapper for the SME
+ `zero` intrinsic and instruction.
+
+ NOTE: At the moment it is assumed that the element type is `i8` and that
+ there's only one "virtual tile".
+
+ Example:
+
+ ```mlir
+ %0 = arm_sme.zero : vector<[16]x[16]xi8>
+ ```
+ }];
+ let extraClassDeclaration = [{
+ VectorType getVectorType() {
+ return ::llvm::cast<VectorType>(getRes().getType());
+ }
+ }];
+ let assemblyFormat = "attr-dict `:` type($res)";
+}
+
+def TileStoreOp : ArmSME_Op<"tile_store"> {
+ let summary = "Tile store operation";
+ let description = [{
+ Store a 2D SME "virtual tile" to memory.
+
+ NOTE: At the moment it is assumed that the element type is `i8` and that
+ there's only one "virtual tile".
+
+ Example:
+
+ ```mlir
+ arm_sme.tile_store %0, %arg0[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
+ ```
+ }];
+ let arguments = (ins nxnxv16i8:$valueToStore,
+ Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
+ Variadic<Index>:$indices);
+ let extraClassDeclaration = [{
+ MemRefType getMemRefType() {
+ return ::llvm::cast<MemRefType>(getBase().getType());
+ }
+ VectorType getVectorType() {
+ return ::llvm::cast<VectorType>(getValueToStore().getType());
+ }
+ }];
+
+ let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict "
+ "`:` type($base) `,` type($valueToStore)";
+}
+
//===----------------------------------------------------------------------===//
// ArmSME Intrinsic op definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index a1c58f53e59862..33efa632872d32 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -47,6 +47,7 @@ add_subdirectory(TosaToArith)
add_subdirectory(TosaToLinalg)
add_subdirectory(TosaToSCF)
add_subdirectory(TosaToTensor)
+add_subdirectory(VectorToArmSME)
add_subdirectory(VectorToLLVM)
add_subdirectory(VectorToGPU)
add_subdirectory(VectorToSCF)
diff --git a/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
new file mode 100644
index 00000000000000..b062f65e914e8b
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_conversion_library(MLIRVectorToArmSME
+ VectorToArmSME.cpp
+ VectorToArmSMEPass.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToArmSME
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRArmSMEDialect
+ MLIRLLVMCommonConversion
+ )
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
new file mode 100644
index 00000000000000..cd0d99c5b5074f
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -0,0 +1,84 @@
+//===- VectorToArmSME.cpp - Conversion from Vector to the ArmSME dialect --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
+
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "llvm/Support/Casting.h"
+
+using namespace mlir;
+
+static constexpr unsigned kMinNumElts = 16;
+
+/// Returns true if 'val' is a splat of zero, false otherwise.
+static bool isSplatZero(Type elemType, DenseElementsAttr val) {
+ if (llvm::isa<FloatType>(elemType))
+ return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
+ if (llvm::isa<IntegerType>(elemType))
+ return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
+ return false;
+}
+
+namespace {
+
+/// Look at `vector.transfer_write` operations and convert suitable candidates
+/// to ArmSME operations, e.g.:
+///
+/// %cst = arith.constant dense<0> : vector<[16]x[16]xi8>
+/// vector.transfer_write %cst, %arg0 : vector<[16]x[16]xi8>, memref<?x?xi8>
+///
+/// is converted to:
+///
+/// %0 = arm_sme.zero : vector<[16]x[16]xi8>
+/// arm_sme.tile_store %0, %arg0[%c0, %c0] : memref<?x?xi8>,
+/// vector<[16]x[16]xi8>
+///
+struct TransferWriteToArmSMELowering
+ : public OpRewritePattern<vector::TransferWriteOp> {
+ using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+ PatternRewriter &rewriter) const final {
+ auto vType = writeOp.getVectorType();
+ if (vType.getRank() != 2)
+ return failure();
+ if (vType.getShape() != ArrayRef<int64_t>({kMinNumElts, kMinNumElts}))
+ return failure();
+ if (vType.getElementType() != rewriter.getI8Type())
+ return failure();
+ if (vType.getScalableDims().size() != 2)
+ return failure();
+
+ auto loc = writeOp.getLoc();
+
+ if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
+ return failure();
+
+ auto constant = writeOp.getVector().getDefiningOp<arith::ConstantOp>();
+ if (!constant)
+ return failure();
+
+ auto denseAttr = dyn_cast<DenseElementsAttr>(constant.getValueAttr());
+ if (!denseAttr || !isSplatZero(vType.getElementType(), denseAttr))
+ return failure();
+
+ auto zero = rewriter.create<arm_sme::ZeroOp>(loc, vType);
+
+ rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
+ writeOp, zero, writeOp.getSource(), writeOp.getIndices());
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
+ MLIRContext &ctx) {
+ patterns.add<TransferWriteToArmSMELowering>(&ctx);
+}
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp
new file mode 100644
index 00000000000000..92025e9fbe82d3
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp
@@ -0,0 +1,36 @@
+//===- VectorToArmSMEPass.cpp - Conversion from Vector to the ArmSME dialect =//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
+
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTVECTORTOARMSME
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+struct ConvertVectorToArmSMEPass
+ : public impl::ConvertVectorToArmSMEBase<ConvertVectorToArmSMEPass> {
+
+ void runOnOperation() override;
+};
+} // namespace
+
+void ConvertVectorToArmSMEPass::runOnOperation() {
+ RewritePatternSet patterns(&getContext());
+ populateVectorToArmSMEPatterns(patterns, getContext());
+
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 6ca7a7d84cfd80..acc4244ce9bb87 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -109,7 +109,6 @@ void LowerVectorToLLVMPass::runOnOperation() {
if (armSME) {
configureArmSMELegalizeForExportTarget(target);
populateArmSMELegalizeForLLVMExportPatterns(converter, patterns);
- arm_sme::populateVectorTransferLoweringPatterns(converter, patterns);
}
if (amx) {
configureAMXLegalizeForExportTarget(target);
diff --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
index 5b30531bc29bb5..9b6332a478ade3 100644
--- a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
@@ -12,4 +12,5 @@ add_mlir_dialect_library(MLIRArmSMEDialect
MLIRLLVMDialect
MLIRSCFDialect
MLIRSideEffectInterfaces
+ MLIRVectorDialect
)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index b9a6bc4fba4530..9f4c3a0ce51a1e 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,7 +1,6 @@
add_mlir_dialect_library(MLIRArmSMETransforms
EnableArmStreaming.cpp
LegalizeForLLVMExport.cpp
- LowerVectorOps.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index aa59aa5b2b3585..cb556d8d4dfe6e 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -8,15 +8,20 @@
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
using namespace mlir;
using namespace mlir::arm_sme;
+static constexpr unsigned kMinNumElts = 16;
+static constexpr unsigned kZeroZAMask = 255;
+
namespace {
/// Insert 'llvm.aarch64.sme.za.enable' intrinsic at the start of 'func.func'
/// ops to enable the ZA storage array.
@@ -58,10 +63,104 @@ struct GetTileIDConversion : public ConvertOpToLLVMPattern<GetTileID> {
};
} // namespace
-void mlir::populateArmSMELegalizeForLLVMExportPatterns(
- LLVMTypeConverter &converter, RewritePatternSet &patterns) {
- patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
-}
+/// Lower 'arm_sme.zero'. Use 'arm_sme.cast_tile_to_vector' to model the return
+/// value. The latter is a nop, which should be folded away (e.g. during
+/// canonicalisation).
+///
+/// BEFORE:
+/// ```mlir
+/// %0 = arm_sme.zero : vector<[16]x[16]xi8>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %1 = arm_sme.get_tile_id : i8
+/// %2 = arm_sme.cast_tile_to_vector %1 : i8 to vector<[16]x[16]xi8>
+/// "arm_sme.intr.zero"(%c255_i32) : (i32) -> ()
+/// ```
+struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
+ using ConvertOpToLLVMPattern<ZeroOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(ZeroOp zero, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = zero.getLoc();
+
+ // Get Tile ID for the `zero` intrinsic.
+ // TODO: Map this to a valid `mask` for the `zero` intrinsic.
+ auto tileId = rewriter.create<arm_sme::GetTileID>(
+ loc, zero.getVectorType().getElementType());
+
+ // Create 'arm_sme.intr.zero' intrinsic to zero ZA.
+ // FIXME: Replace the hard-coded mask with a valid value based
+ // on `tileId`.
+ auto mask = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(kZeroZAMask));
+ rewriter.create<arm_sme::aarch64_sme_zero>(loc, mask);
+
+ // Create `CastTileToVectorOp` to use it as the output
+ rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(zero, zero.getType(),
+ tileId);
+
+ return success();
+ }
+};
+
+/// Lower 'arm_sme.tile_store' to a loop over the rows of ZA and store each row
+/// using 'arm_sme.intr.str'.
+///
+/// BEFORE:
+/// ```mlir
+/// arm_sme.tile_store %0, %arg0[%c0, %c0] : memref<?x?xi8>,
+/// vector<[16]x[16]xi8>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %vscale = "llvm.intr.vscale"() : () -> index
+/// %c0 = arith.constant 0 : index
+/// %c1 = arith.constant 1 : index
+/// %c16 = arith.constant 16 : index
+/// %vec_size = arith.muli %c16, %vscale : index
+/// scf.for %row_idx = %c0 to %vec_size step %c1 {
+/// // (...)
+/// "arm_sme.intr.str"(%row_idx, %addr) : (i32, !llvm.ptr) -> ()
+/// ```
+struct TileStoreOpConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
+ using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(TileStoreOp store, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = store.getLoc();
+
+ // Create loop that iterates from 0 to SVLB-1 inclusive (the number of
+ // vectors in ZA) and stores each ZA vector to memory.
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto minElems = rewriter.create<arith::ConstantIndexOp>(loc, kMinNumElts);
+ auto vscale =
+ rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+ auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto upperBound = rewriter.create<arith::MulIOp>(loc, minElems, vscale);
+ auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+ rewriter.setInsertionPointToStart(forOp.getBody());
+
+ // Create 'arm_sme.intr.str' intrinsic to store ZA vector.
+ auto vnumI64 = rewriter.create<arith::IndexCastUIOp>(
+ loc, rewriter.getI64Type(), forOp.getInductionVar());
+ auto offset =
+ rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0);
+ Value ptr =
+ getStridedElementPtr(loc, store.getMemRefType(), adaptor.getBase(),
+ ValueRange{vnumI64, offset}, rewriter);
+ auto vnumI32 = rewriter.create<arith::IndexCastUIOp>(
+ loc, rewriter.getI32Type(), forOp.getInductionVar());
+ rewriter.create<arm_sme::aarch64_sme_str>(loc, vnumI32, ptr);
+
+ rewriter.eraseOp(store);
+ return success();
+ }
+};
void mlir::configureArmSMELegalizeForExportTarget(
LLVMConversionTarget &target) {
@@ -95,3 +194,9 @@ void mlir::configureArmSMELegalizeForExportTarget(
return !funcOp->hasAttr("arm_za") || hasDisableZA;
});
}
+
+void mlir::populateArmSMELegalizeForLLVMExportPatterns(
+ LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+ patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
+ patterns.add<TileStoreOpConversion, ZeroOpConversion>(converter);
+}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp
deleted file mode 100644
index dfda09d2619e90..00000000000000
--- a/mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp
+++ /dev/null
@@ -1,111 +0,0 @@
-//===- LowerVectorOps.cpp - Lower vector ops to SME -----------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements rewrite patterns to lower vector dialect ops to ArmSME.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
-#include "mlir/Conversion/LLVMCommon/Pattern.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
-#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/PatternMatch.h"
-
-using namespace mlir;
-using namespace mlir::arm_sme;
-
-static constexpr unsigned kMinNumElts = 16;
-static constexpr unsigned kZeroZAMask = 255;
-
-/// Returns true if 'val' is a splat of zero, false otherwise.
-static bool isSplatZero(Type elemType, DenseElementsAttr val) {
- if (llvm::isa<FloatType>(elemType))
- return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
- if (llvm::isa<IntegerType>(elemType))
- return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
- return false;
-}
-
-namespace {
-/// Lower 'vector.transfer_write' op to 'arm_sme.intr.zero' op. Currently only
-/// supports 2d scalable vector type 'vector<[16x16]xi8>' that maps to the ZA0.B
-/// SME virtual tile. This will be extended to support more element types.
-struct TransferWriteToArmSMEZeroLowering
- : public ConvertOpToLLVMPattern<vector::TransferWriteOp> {
- using ConvertOpToLLVMPattern<vector::TransferWriteOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(vector::TransferWriteOp write, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto vType = write.getVectorType();
- if (vType.getRank() != 2)
- return failure();
- if (vType.getShape() != ArrayRef<int64_t>({kMinNumElts, kMinNumElts}))
- return failure();
- if (vType.getElementType() != rewriter.getI8Type())
- return failure();
- if (vType.getScalableDims().size() != 2)
- return failure();
-
- auto memRefType = llvm::dyn_cast<MemRefType>(write.getSource().getType());
- if (!memRefType)
- return failure();
-
- auto constant = write.getVector().getDefiningOp<arith::ConstantOp>();
- if (!constant)
- return failure();
-
- auto denseAttr = dyn_cast<DenseElementsAttr>(constant.getValueAttr());
- if (!denseAttr || !isSplatZero(vType.getElementType(), denseAttr))
- return failure();
-
- auto loc = write.getLoc();
-
- // Create 'arm_sme.intr.zero' intrinsic to zero ZA.
- auto tile = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(kZeroZAMask));
- rewriter.create<arm_sme::aarch64_sme_zero>(loc, tile);
-
- // Create loop that iterates from 0 to SVLB-1 inclusive (the number of
- // vectors in ZA) and stores each ZA vector to memory.
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto minElems = rewriter.create<arith::ConstantIndexOp>(loc, kMinNumElts);
- auto vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto upperBound = rewriter.create<arith::MulIOp>(loc, minElems, vscale);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
- rewriter.setInsertionPointToStart(forOp.getBody());
-
- // Create 'arm_sme.intr.str' intrinsic to store ZA vector.
- auto vnumI64 = rewriter.create<arith::IndexCastUIOp>(
- loc, rewriter.getI64Type(), forOp.getInductionVar());
- auto offset =
- rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0);
- Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getSource(),
- ValueRange{vnumI64, offset}, rewriter);
- auto vnumI32 = rewriter.create<arith::IndexCastUIOp>(
- loc, rewriter.getI32Type(), forOp.getInductionVar());
- rewriter.create<arm_sme::aarch64_sme_str>(loc, vnumI32, ptr);
-
- rewriter.eraseOp(write);
-
- return success();
- }
-};
-} // namespace
-
-void mlir::arm_sme::populateVectorTransferLoweringPatterns(
- LLVMTypeConverter &converter, RewritePatternSet &patterns) {
- patterns.add<TransferWriteToArmSMEZeroLowering>(converter);
-}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 6256b5bc062d31..5c1f3a9e26db0d 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -183,3 +183,20 @@ func.func @arm_sme_get_tile_id_i128() -> i128 {
%0 = arm_sme.get_tile_id : i128
return %0 : i128
}
+
+// -----
+
+func.func @arm_sme_zero() -> () {
+ // CHECK: arm_sme.zero : vector<[16]x[16]xi8>
+ %0 = arm_sme.zero : vector<[16]x[16]xi8>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) -> () {
+ // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %c0 = arith.constant 0 : index
+ arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
new file mode 100644
index 00000000000000..cb52ab5ff1f134
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-vector-to-llvm="enable-arm-sme" -split-input-file | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: @transfer_write_2d_zero_i8
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
+// CHECK-DAG: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-DAG: %[[C255:.*]] = arith.constant 255 : i32
+// CHECK-DAG: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> ()
+// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
+// CHECK-DAG: %[[CAST_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[MIN_ZA_VECTORS:.*]] = arith.constant 16 : index
+// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index
+// CHECK-NEXT: scf.for %[[VNUM:.*]] = %[[C0_0]] to %[[NUM_ZA_VECTORS]] step %[[C1]] {
+// CHECK-NEXT: %[[VNUM_I64:.*]] = arith.index_castui %[[VNUM]] : index to i64
+// CHECK-NEXT: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[VNUM_I64]], %[[STRIDE0]] : i64
+// CHECK-NEXT: %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_1]] : i64
+// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT: %[[VNUM_I32:.*]] = arith.index_castui %[[VNUM]] : index to i32
+// CHECK-NEXT: "arm_sme.intr.str"(%[[VNUM_I32]], %[[GEP]]) : (i32, !llvm.ptr) -> ()
+func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant dense<0> : vector<[16]x[16]xi8>
+ vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+ return
+}
+
diff --git a/mlir/test/Dialect/ArmSME/vector-ops.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
similarity index 57%
rename from mlir/test/Dialect/ArmSME/vector-ops.mlir
rename to mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 19b9896bc42a29..f3440e4fc61bf5 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -1,27 +1,13 @@
-// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" -split-input-file | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file | mlir-opt | FileCheck %s
-// CHECK-LABEL: @transfer_write_2d_zero_i8
-// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
-// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[C255:.*]] = arith.constant 255 : i32
-// CHECK-NEXT: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> ()
-// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 16 : index
-// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
-// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
-// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index
-// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index
-// CHECK-NEXT: scf.for %[[VNUM:.*]] = %[[C0_0]] to %[[NUM_ZA_VECTORS]] step %[[C1]] {
-// CHECK-NEXT: %[[VNUM_I64:.*]] = arith.index_castui %[[VNUM]] : index to i64
-// CHECK-NEXT: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[VNUM_I64]], %[[STRIDE0]] : i64
-// CHECK-NEXT: %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_1]] : i64
-// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
-// CHECK-NEXT: %[[VNUM_I32:.*]] = arith.index_castui %[[VNUM]] : index to i32
-// CHECK-NEXT: "arm_sme.intr.str"(%[[VNUM_I32]], %[[GEP]]) : (i32, !llvm.ptr) -> ()
-func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
+
+// CHECK-LABEL: func.func @transfer_write_2d_zero(
+// CHECK-SAME: %[[ARG_0:.*]]: memref<?x?xi8>) {
+func.func @transfer_write_2d_zero(%arg0 : memref<?x?xi8>) {
+// CHECK: %[[C_0:.*]] = arith.constant 0 : index
+// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[16]x[16]xi8>
+// CHECK: arm_sme.tile_store %[[ZERO]], %[[ARG_0]][%[[C_0]], %[[C_0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
+// CHECK: return
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0> : vector<[16]x[16]xi8>
vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
@@ -48,7 +34,7 @@ func.func @transfer_write_2d_zero__bad_type(%arg0 : memref<?x?xi4>) {
// CHECK-LABEL: @transfer_write_2d_zero__bad_shape
// CHECK: vector.transfer_write
-// CHECK-NOT: arm_sme.intr.zero
+// CHECK-NOT: arm_sme.tile_store
func.func @transfer_write_2d_zero__bad_shape(%arg0 : memref<?x?xi8>) {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0> : vector<[8]x[8]xi8>
@@ -60,7 +46,7 @@ func.func @transfer_write_2d_zero__bad_shape(%arg0 : memref<?x?xi8>) {
// CHECK-LABEL: @transfer_write_2d_zero__bad_rank
// CHECK: vector.transfer_write
-// CHECK-NOT: arm_sme.intr.zero
+// CHECK-NOT: arm_sme.tile_store
func.func @transfer_write_2d_zero__bad_rank(%arg0 : memref<?x?x?xi8>) {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0> : vector<[16]x[16]x[16]xi8>
@@ -72,7 +58,7 @@ func.func @transfer_write_2d_zero__bad_rank(%arg0 : memref<?x?x?xi8>) {
// CHECK-LABEL: @transfer_write_2d_zero__non_memref_type
// CHECK: vector.transfer_write
-// CHECK-NOT: arm_sme.intr.zero
+// CHECK-NOT: arm_sme.tile_store
func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> tensor<?x?xi8> {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0> : vector<[16]x[16]xi8>
@@ -84,7 +70,7 @@ func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> te
// CHECK-LABEL: @transfer_write_2d_zero__non_zero_value
// CHECK: vector.transfer_write
-// CHECK-NOT: arm_sme.intr.zero
+// CHECK-NOT: arm_sme.tile_store
func.func @transfer_write_2d_zero__non_zero_value(%arg0 : memref<?x?xi8>) {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<1> : vector<[16]x[16]xi8>
@@ -96,7 +82,7 @@ func.func @transfer_write_2d_zero__non_zero_value(%arg0 : memref<?x?xi8>) {
// CHECK-LABEL: @transfer_write_2d_zero__vec_unknown_defining_op
// CHECK: vector.transfer_write
-// CHECK-NOT: arm_sme.intr.zero
+// CHECK-NOT: arm_sme.tile_store
func.func @transfer_write_2d_zero__vec_unknown_defining_op(%arg0 : memref<?x?xi8>, %arg1 : vector<[16]x[16]xi8>) {
%c0 = arith.constant 0 : index
vector.transfer_write %arg1, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
index 70b53cfa8cf855..31a49a422192f7 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \
+// RUN: mlir-opt %s -convert-vector-to-arm-sme -enable-arm-streaming="mode=locally enable-za" \
// RUN: -convert-vector-to-llvm="enable-arm-sme" -test-lower-to-llvm | \
// RUN: mlir-translate -mlir-to-llvmir | \
// RUN: %lli_aarch64_cmd --march=aarch64 --mattr="+sve,+sme" \
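For reference, the intermediate IR produced by the new `-convert-vector-to-arm-sme` pass for the `transfer_write_2d_zero` test above looks roughly like this. This is a sketch assembled from the CHECK lines in the updated test, not verbatim pass output; exact SSA value names may differ:

```mlir
func.func @transfer_write_2d_zero(%arg0 : memref<?x?xi8>) {
  %c0 = arith.constant 0 : index
  // The dense<0> constant is rewritten to the new custom SME op...
  %zero = arm_sme.zero : vector<[16]x[16]xi8>
  // ...and the vector.transfer_write becomes a tile store.
  arm_sme.tile_store %zero, %arg0[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
  return
}
```

A later `-convert-vector-to-llvm="enable-arm-sme"` run then lowers these two ops to the `arm_sme.intr.zero`/`arm_sme.intr.str` intrinsics shown in the removed CHECK lines.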