[Mlir-commits] [mlir] 564713c - [mlir][ArmSME] Add basic lowering of vector.transfer_write to zero
Cullen Rhodes
llvmlistbot at llvm.org
Mon Jul 3 03:26:03 PDT 2023
Author: Cullen Rhodes
Date: 2023-07-03T10:18:43Z
New Revision: 564713c47175d9f61fe8e18e750fb7f7b486d533
URL: https://github.com/llvm/llvm-project/commit/564713c47175d9f61fe8e18e750fb7f7b486d533
DIFF: https://github.com/llvm/llvm-project/commit/564713c47175d9f61fe8e18e750fb7f7b486d533.diff
LOG: [mlir][ArmSME] Add basic lowering of vector.transfer_write to zero
This patch adds support for lowering a 'vector.transfer_write' of zeroes
of type 'vector<[16]x[16]xi8>' to the SME 'zero {za}' instruction [1],
which zeroes the entire accumulator, followed by a write of the
accumulator to memory with the 'str' instruction [2].
This contributes to supporting a path from 'linalg.fill' to SME.
[1] https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
[2] https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/STR--Store-vector-from-ZA-array-
Reviewed By: awarzynski, dcaballe, WanderAway
Differential Revision: https://reviews.llvm.org/D152508
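For illustration, a minimal sketch of the input IR this lowering targets
(the function name is illustrative; the exact form is covered by the
added tests):

  func.func @zero_fill(%dst : memref<?x?xi8>) {
    %c0 = arith.constant 0 : index
    %zeroes = arith.constant dense<0> : vector<[16]x[16]xi8>
    // A transfer_write of a splat of zeroes is rewritten into
    // 'arm_sme.intr.zero' (zero ZA) plus an 'scf.for' loop that stores
    // each ZA vector to memory with 'arm_sme.intr.str'.
    vector.transfer_write %zeroes, %dst[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
    return
  }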
Added:
mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp
mlir/test/Dialect/ArmSME/vector-ops.mlir
mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
Modified:
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
mlir/test/Target/LLVMIR/arm-sme.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index a69d32610357c8..dacf23ceca2de0 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -14,6 +14,7 @@
#define MLIR_DIALECT_ARMSME_IR_ARMSME_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index d0072b6149cd93..140ed51b101b97 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -33,6 +33,7 @@ def ArmSME_Dialect : Dialect {
https://developer.arm.com/documentation/ddi0616
https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
}];
+ let dependentDialects = ["scf::SCFDialect"];
}
//===----------------------------------------------------------------------===//
@@ -119,6 +120,11 @@ def LLVM_aarch64_sme_st1w_vert : ArmSME_IntrStoreOp<"st1w.vert">;
def LLVM_aarch64_sme_st1d_vert : ArmSME_IntrStoreOp<"st1d.vert">;
def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">;
+def LLVM_aarch64_sme_str
+ : ArmSME_IntrOp<"str">,
+ Arguments<(ins Arg<I32, "Index">,
+ Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
+
def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
index c8d9e6704ac573..fae04513859938 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
@@ -15,6 +15,11 @@ class LLVMConversionTarget;
class LLVMTypeConverter;
class RewritePatternSet;
+namespace arm_sme {
+void populateVectorTransferLoweringPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns);
+} // namespace arm_sme
+
/// Collect a set of patterns to lower ArmSME ops to ops that map to LLVM
/// intrinsics.
void populateArmSMELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index acc4244ce9bb87..6ca7a7d84cfd80 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -109,6 +109,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
if (armSME) {
configureArmSMELegalizeForExportTarget(target);
populateArmSMELegalizeForLLVMExportPatterns(converter, patterns);
+ arm_sme::populateVectorTransferLoweringPatterns(converter, patterns);
}
if (amx) {
configureAMXLegalizeForExportTarget(target);
diff --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
index afe69de7133067..5b30531bc29bb5 100644
--- a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
@@ -10,5 +10,6 @@ add_mlir_dialect_library(MLIRArmSMEDialect
LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMDialect
+ MLIRSCFDialect
MLIRSideEffectInterfaces
)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index efcb17fd2076a7..b9a6bc4fba4530 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRArmSMETransforms
EnableArmStreaming.cpp
LegalizeForLLVMExport.cpp
+ LowerVectorOps.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
@@ -12,5 +13,7 @@ add_mlir_dialect_library(MLIRArmSMETransforms
MLIRArmSMEDialect
MLIRFuncDialect
MLIRLLVMCommonConversion
+ MLIRVectorDialect
+ MLIRSCFDialect
MLIRPass
)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index a946126fc8ea04..2eb061da49f440 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
using namespace mlir;
using namespace mlir::arm_sme;
@@ -51,7 +52,8 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
void mlir::configureArmSMELegalizeForExportTarget(
LLVMConversionTarget &target) {
- target.addLegalOp<arm_sme::aarch64_sme_za_enable,
+ target.addLegalOp<scf::ForOp, scf::YieldOp, arm_sme::aarch64_sme_zero,
+ arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_za_enable,
arm_sme::aarch64_sme_za_disable>();
// Mark 'func.func' ops as legal if either:
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp
new file mode 100644
index 00000000000000..dfda09d2619e90
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp
@@ -0,0 +1,111 @@
+//===- LowerVectorOps.cpp - Lower vector ops to SME -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewrite patterns to lower vector dialect ops to ArmSME.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+static constexpr unsigned kMinNumElts = 16;
+static constexpr unsigned kZeroZAMask = 255;
+
+/// Returns true if 'val' is a splat of zero, false otherwise.
+static bool isSplatZero(Type elemType, DenseElementsAttr val) {
+ if (llvm::isa<FloatType>(elemType))
+ return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
+ if (llvm::isa<IntegerType>(elemType))
+ return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
+ return false;
+}
+
+namespace {
+/// Lower 'vector.transfer_write' op to 'arm_sme.intr.zero' op. Currently only
+/// supports 2d scalable vector type 'vector<[16]x[16]xi8>' that maps to the ZA0.B
+/// SME virtual tile. This will be extended to support more element types.
+struct TransferWriteToArmSMEZeroLowering
+ : public ConvertOpToLLVMPattern<vector::TransferWriteOp> {
+ using ConvertOpToLLVMPattern<vector::TransferWriteOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::TransferWriteOp write, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto vType = write.getVectorType();
+ if (vType.getRank() != 2)
+ return failure();
+ if (vType.getShape() != ArrayRef<int64_t>({kMinNumElts, kMinNumElts}))
+ return failure();
+ if (vType.getElementType() != rewriter.getI8Type())
+ return failure();
+ if (vType.getScalableDims().size() != 2)
+ return failure();
+
+ auto memRefType = llvm::dyn_cast<MemRefType>(write.getSource().getType());
+ if (!memRefType)
+ return failure();
+
+ auto constant = write.getVector().getDefiningOp<arith::ConstantOp>();
+ if (!constant)
+ return failure();
+
+ auto denseAttr = dyn_cast<DenseElementsAttr>(constant.getValueAttr());
+ if (!denseAttr || !isSplatZero(vType.getElementType(), denseAttr))
+ return failure();
+
+ auto loc = write.getLoc();
+
+ // Create 'arm_sme.intr.zero' intrinsic to zero ZA.
+ auto tile = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(kZeroZAMask));
+ rewriter.create<arm_sme::aarch64_sme_zero>(loc, tile);
+
+ // Create loop that iterates from 0 to SVLB-1 inclusive (the number of
+ // vectors in ZA) and stores each ZA vector to memory.
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto minElems = rewriter.create<arith::ConstantIndexOp>(loc, kMinNumElts);
+ auto vscale =
+ rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+ auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto upperBound = rewriter.create<arith::MulIOp>(loc, minElems, vscale);
+ auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+ rewriter.setInsertionPointToStart(forOp.getBody());
+
+ // Create 'arm_sme.intr.str' intrinsic to store ZA vector.
+ auto vnumI64 = rewriter.create<arith::IndexCastUIOp>(
+ loc, rewriter.getI64Type(), forOp.getInductionVar());
+ auto offset =
+ rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0);
+ Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getSource(),
+ ValueRange{vnumI64, offset}, rewriter);
+ auto vnumI32 = rewriter.create<arith::IndexCastUIOp>(
+ loc, rewriter.getI32Type(), forOp.getInductionVar());
+ rewriter.create<arm_sme::aarch64_sme_str>(loc, vnumI32, ptr);
+
+ rewriter.eraseOp(write);
+
+ return success();
+ }
+};
+} // namespace
+
+void mlir::arm_sme::populateVectorTransferLoweringPatterns(
+ LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+ patterns.add<TransferWriteToArmSMEZeroLowering>(converter);
+}
diff --git a/mlir/test/Dialect/ArmSME/vector-ops.mlir b/mlir/test/Dialect/ArmSME/vector-ops.mlir
new file mode 100644
index 00000000000000..19b9896bc42a29
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/vector-ops.mlir
@@ -0,0 +1,104 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" -split-input-file | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: @transfer_write_2d_zero_i8
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
+// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[C255:.*]] = arith.constant 255 : i32
+// CHECK-NEXT: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> ()
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 16 : index
+// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index
+// CHECK-NEXT: scf.for %[[VNUM:.*]] = %[[C0_0]] to %[[NUM_ZA_VECTORS]] step %[[C1]] {
+// CHECK-NEXT: %[[VNUM_I64:.*]] = arith.index_castui %[[VNUM]] : index to i64
+// CHECK-NEXT: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[VNUM_I64]], %[[STRIDE0]] : i64
+// CHECK-NEXT: %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_1]] : i64
+// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT: %[[VNUM_I32:.*]] = arith.index_castui %[[VNUM]] : index to i32
+// CHECK-NEXT: "arm_sme.intr.str"(%[[VNUM_I32]], %[[GEP]]) : (i32, !llvm.ptr) -> ()
+func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant dense<0> : vector<[16]x[16]xi8>
+ vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+ return
+}
+
+// -----
+
+// The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
+// lowering only occurs for vector types of correct rank, shape, element size
+// and number of scalable dims.
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_type
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_type(%arg0 : memref<?x?xi4>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant dense<0> : vector<[16]x[16]xi4>
+ vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi4>, memref<?x?xi4>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_shape
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_shape(%arg0 : memref<?x?xi8>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant dense<0> : vector<[8]x[8]xi8>
+ vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xi8>, memref<?x?xi8>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_rank
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_rank(%arg0 : memref<?x?x?xi8>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant dense<0> : vector<[16]x[16]x[16]xi8>
+ vector.transfer_write %cst, %arg0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<[16]x[16]x[16]xi8>, memref<?x?x?xi8>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__non_memref_type
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> tensor<?x?xi8> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant dense<0> : vector<[16]x[16]xi8>
+ %0 = vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, tensor<?x?xi8>
+ return %0 : tensor<?x?xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__non_zero_value
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__non_zero_value(%arg0 : memref<?x?xi8>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant dense<1> : vector<[16]x[16]xi8>
+ vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__vec_unknown_defining_op
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__vec_unknown_defining_op(%arg0 : memref<?x?xi8>, %arg1 : vector<[16]x[16]xi8>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg1, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
new file mode 100644
index 00000000000000..70b53cfa8cf855
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
@@ -0,0 +1,142 @@
+// RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \
+// RUN: -convert-vector-to-llvm="enable-arm-sme" -test-lower-to-llvm | \
+// RUN: mlir-translate -mlir-to-llvmir | \
+// RUN: %lli_aarch64_cmd --march=aarch64 --mattr="+sve,+sme" \
+// RUN: --entry-function=entry \
+// RUN: --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func.func @entry() -> i32 {
+ %c0 = arith.constant 0 : index
+ %c1_i8 = arith.constant 1 : i8
+ %c1_index = arith.constant 1 : index
+
+ %c16 = arith.constant 16 : index
+ %vscale = vector.vscale
+
+ // "svl" refers to the Streaming Vector Length and "svl_b" the number of
+ // 8-bit elements in a vector of SVL bits.
+ %svl_b = arith.muli %c16, %vscale : index
+
+ // Allocate memory and fill with ones.
+ //
+ // TODO: type conversion of rank > 1 vector types generates array(s) of
+ // vectors. This is invalid for scalable vectors since LLVM doesn't support
+ // arrays of scalable vectors. This prevents initializing 2-d vectors with
+ // 'vector.store' or 'vector.transfer_write' ops until this is resolved or
+ // there's a custom lowering path.
+ %za_b = memref.alloca(%svl_b, %svl_b) : memref<?x?xi8>
+ scf.for %i = %c0 to %svl_b step %c1_index {
+ scf.for %j = %c0 to %svl_b step %c1_index {
+ memref.store %c1_i8, %za_b[%i, %j] : memref<?x?xi8>
+ }
+ }
+
+ // Verify memory is ones by doing a mul reduction with initial value of one.
+ %init_1 = arith.constant 1 : i64
+ %mul_reduce = scf.for %vnum = %c0 to %svl_b step %c1_index iter_args(%iter = %init_1) -> (i64) {
+ %row = vector.load %za_b[%vnum, %c0] : memref<?x?xi8>, vector<[16]xi8>
+
+ %inner_mul_reduce = scf.for %offset = %c0 to %svl_b step %c1_index iter_args(%inner_iter = %init_1) -> (i64) {
+ %t = vector.extractelement %row[%offset : index] : vector<[16]xi8>
+ %t_i64 = arith.extui %t : i8 to i64
+ %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64
+ scf.yield %inner_mul_reduce_next : i64
+ }
+
+ %mul_reduce_next = arith.muli %iter, %inner_mul_reduce : i64
+ scf.yield %mul_reduce_next : i64
+ }
+
+ // CHECK: 1
+ vector.print %mul_reduce : i64
+
+ // Verify the mul reduction works as expected.
+ //
+ // TODO: ZA currently isn't re-enabled after calls and is therefore disabled
+ // by the callee on return. Once this is resolved this can be moved to a
+ // function.
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : i8
+ %c7 = arith.constant 7 : index
+ %c15 = arith.constant 15 : i8
+ memref.store %c4, %za_b[%c3, %c7] : memref<?x?xi8>
+ memref.store %c15, %za_b[%c7, %c3] : memref<?x?xi8>
+ %mul_reduce2 = scf.for %vnum = %c0 to %svl_b step %c1_index iter_args(%iter = %init_1) -> (i64) {
+ %row = vector.load %za_b[%vnum, %c0] : memref<?x?xi8>, vector<[16]xi8>
+
+ %inner_mul_reduce = scf.for %offset = %c0 to %svl_b step %c1_index iter_args(%inner_iter = %init_1) -> (i64) {
+ %t = vector.extractelement %row[%offset : index] : vector<[16]xi8>
+ %t_i64 = arith.extui %t : i8 to i64
+ %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64
+ scf.yield %inner_mul_reduce_next : i64
+ }
+
+ %mul_reduce_next = arith.muli %iter, %inner_mul_reduce : i64
+ scf.yield %mul_reduce_next : i64
+ }
+
+ // 15*4=60
+ // CHECK: 60
+ vector.print %mul_reduce2 : i64
+
+ // Fill memory with zeroes.
+ //
+ // This will get lowered to:
+ //
+ // zero {za}
+ // for vnum = 0; vnum < SVLb; ++vnum;
+ // str za[vnum], [ptr]
+ // ...
+ //
+ %cst_0 = arith.constant dense<0> : vector<[16]x[16]xi8>
+ vector.transfer_write %cst_0, %za_b[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+
+ // Verify memory is zeroed by doing an add reduction with initial value of
+ // zero.
+ %init_0 = arith.constant 0 : i64
+ %add_reduce = scf.for %vnum = %c0 to %svl_b step %c1_index iter_args(%iter = %init_0) -> (i64) {
+ %row = vector.load %za_b[%vnum, %c0] : memref<?x?xi8>, vector<[16]xi8>
+
+ %inner_add_reduce = scf.for %offset = %c0 to %svl_b step %c1_index iter_args(%inner_iter = %init_0) -> (i64) {
+ %t = vector.extractelement %row[%offset : index] : vector<[16]xi8>
+ %t_i64 = arith.extui %t : i8 to i64
+ %inner_add_reduce_next = arith.addi %inner_iter, %t_i64 : i64
+ scf.yield %inner_add_reduce_next : i64
+ }
+
+ %add_reduce_next = arith.addi %iter, %inner_add_reduce : i64
+ scf.yield %add_reduce_next : i64
+ }
+
+ // CHECK-NEXT: 0
+ vector.print %add_reduce : i64
+
+ // Verify the add reduction works as expected.
+ //
+ // TODO: ZA currently isn't re-enabled after calls and is therefore disabled
+ // by the callee on return. Once this is resolved this can be moved to a
+ // function.
+ memref.store %c4, %za_b[%c3, %c7] : memref<?x?xi8>
+ memref.store %c15, %za_b[%c7, %c3] : memref<?x?xi8>
+ %add_reduce2 = scf.for %vnum = %c0 to %svl_b step %c1_index iter_args(%iter = %init_0) -> (i64) {
+ %row = vector.load %za_b[%vnum, %c0] : memref<?x?xi8>, vector<[16]xi8>
+
+ %inner_add_reduce = scf.for %offset = %c0 to %svl_b step %c1_index iter_args(%inner_iter = %init_0) -> (i64) {
+ %t = vector.extractelement %row[%offset : index] : vector<[16]xi8>
+ %t_i64 = arith.extui %t : i8 to i64
+ %inner_add_reduce_next = arith.addi %inner_iter, %t_i64 : i64
+ scf.yield %inner_add_reduce_next : i64
+ }
+
+ %add_reduce_next = arith.addi %iter, %inner_add_reduce : i64
+ scf.yield %add_reduce_next : i64
+ }
+
+ // 15+4=19
+ // CHECK-NEXT: 19
+ vector.print %add_reduce2 : i64
+
+ %c0_i32 = arith.constant 0 : i32
+ return %c0_i32 : i32
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index 453a887c6d7fde..7beec1f61aa923 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -221,6 +221,8 @@ llvm.func @arm_sme_store(%nxv1i1 : vector<[1]xi1>,
// CHECK: call void @llvm.aarch64.sme.st1b.vert
"arm_sme.intr.st1b.vert"(%nxv16i1, %p8, %c0, %c0) :
(vector<[16]xi1>, !llvm.ptr<i8>, i32, i32) -> ()
+ // CHECK: call void @llvm.aarch64.sme.str
+ "arm_sme.intr.str"(%c0, %p8) : (i32, !llvm.ptr<i8>) -> ()
llvm.return
}
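For reference, a minimal usage sketch mirroring the RUN lines above (the
input file name is illustrative). The patterns added here are only
registered when the ArmSME option is enabled on the VectorToLLVM
conversion:

  mlir-opt input.mlir -enable-arm-streaming="mode=locally enable-za" \
    -convert-vector-to-llvm="enable-arm-sme"

Without 'enable-arm-sme', these patterns are not registered and the
'vector.transfer_write' is left untouched by them.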