[Mlir-commits] [mlir] 564713c - [mlir][ArmSME] Add basic lowering of vector.transfer_write to zero

Cullen Rhodes llvmlistbot at llvm.org
Mon Jul 3 03:26:03 PDT 2023


Author: Cullen Rhodes
Date: 2023-07-03T10:18:43Z
New Revision: 564713c47175d9f61fe8e18e750fb7f7b486d533

URL: https://github.com/llvm/llvm-project/commit/564713c47175d9f61fe8e18e750fb7f7b486d533
DIFF: https://github.com/llvm/llvm-project/commit/564713c47175d9f61fe8e18e750fb7f7b486d533.diff

LOG: [mlir][ArmSME] Add basic lowering of vector.transfer_write to zero

This patch adds support for lowering a 'vector.transfer_write' of zeroes
and type 'vector<[16]x[16]xi8>' to the SME 'zero {za}' instruction [1],
which zeroes the entire accumulator, followed by writing it out to memory
with the 'str' instruction [2].

This contributes to supporting a path from 'linalg.fill' to SME.

[1] https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
[2] https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/STR--Store-vector-from-ZA-array-

Reviewed By: awarzynski, dcaballe, WanderAway

Differential Revision: https://reviews.llvm.org/D152508

Added: 
    mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp
    mlir/test/Dialect/ArmSME/vector-ops.mlir
    mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir

Modified: 
    mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
    mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
    mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
    mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
    mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
    mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
    mlir/test/Target/LLVMIR/arm-sme.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index a69d32610357c8..dacf23ceca2de0 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_ARMSME_IR_ARMSME_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"

diff  --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index d0072b6149cd93..140ed51b101b97 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -33,6 +33,7 @@ def ArmSME_Dialect : Dialect {
     https://developer.arm.com/documentation/ddi0616
     https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
   }];
+  let dependentDialects = ["scf::SCFDialect"];
 }
 
 //===----------------------------------------------------------------------===//
@@ -119,6 +120,11 @@ def LLVM_aarch64_sme_st1w_vert : ArmSME_IntrStoreOp<"st1w.vert">;
 def LLVM_aarch64_sme_st1d_vert : ArmSME_IntrStoreOp<"st1d.vert">;
 def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">;
 
+def LLVM_aarch64_sme_str // Store one vector of the ZA array to memory.
+    : ArmSME_IntrOp<"str">,
+      Arguments<(ins Arg<I32, "Index">, // Number of the ZA vector to store.
+                 Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
+
 def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
 def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
 

diff  --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
index c8d9e6704ac573..fae04513859938 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
@@ -15,6 +15,12 @@ class LLVMConversionTarget;
 class LLVMTypeConverter;
 class RewritePatternSet;
 
+namespace arm_sme {
+/// Collect a set of patterns to lower vector ops to ArmSME intrinsics.
+void populateVectorTransferLoweringPatterns(LLVMTypeConverter &converter,
+                                            RewritePatternSet &patterns);
+} // namespace arm_sme
+
 /// Collect a set of patterns to lower ArmSME ops to ops that map to LLVM
 /// intrinsics.
 void populateArmSMELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index acc4244ce9bb87..6ca7a7d84cfd80 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -109,6 +109,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
   if (armSME) {
     configureArmSMELegalizeForExportTarget(target);
     populateArmSMELegalizeForLLVMExportPatterns(converter, patterns);
+    arm_sme::populateVectorTransferLoweringPatterns(converter, patterns);
   }
   if (amx) {
     configureAMXLegalizeForExportTarget(target);

diff  --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
index afe69de7133067..5b30531bc29bb5 100644
--- a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
@@ -10,5 +10,6 @@ add_mlir_dialect_library(MLIRArmSMEDialect
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRLLVMDialect
+  MLIRSCFDialect
   MLIRSideEffectInterfaces
 )

diff  --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index efcb17fd2076a7..b9a6bc4fba4530 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRArmSMETransforms
   EnableArmStreaming.cpp
   LegalizeForLLVMExport.cpp
+  LowerVectorOps.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
@@ -12,5 +13,7 @@ add_mlir_dialect_library(MLIRArmSMETransforms
   MLIRArmSMEDialect
   MLIRFuncDialect
   MLIRLLVMCommonConversion
+  MLIRVectorDialect
+  MLIRSCFDialect
   MLIRPass
   )

diff  --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index a946126fc8ea04..2eb061da49f440 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 
 using namespace mlir;
 using namespace mlir::arm_sme;
@@ -51,7 +52,8 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
 
 void mlir::configureArmSMELegalizeForExportTarget(
     LLVMConversionTarget &target) {
-  target.addLegalOp<arm_sme::aarch64_sme_za_enable,
+  target.addLegalOp<scf::ForOp, scf::YieldOp, arm_sme::aarch64_sme_zero,
+                    arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_za_enable,
                     arm_sme::aarch64_sme_za_disable>();
 
   // Mark 'func.func' ops as legal if either:

diff  --git a/mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp
new file mode 100644
index 00000000000000..dfda09d2619e90
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp
@@ -0,0 +1,143 @@
+//===- LowerVectorOps.cpp - Lower vector ops to SME -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewrite patterns to lower vector dialect ops to ArmSME.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+/// Number of elements in each dimension of 'vector<[16]x[16]xi8>' when
+/// vscale is 1; at runtime each dimension holds kMinNumElts * vscale elements.
+static constexpr unsigned kMinNumElts = 16;
+/// Mask (0xFF) for 'arm_sme.intr.zero' selecting every tile, i.e. all of ZA.
+static constexpr unsigned kZeroZAMask = 255;
+
+/// Returns true if 'val' is a splat of zero, false otherwise. Floating-point
+/// splats must be +0.0: SME 'zero' clears every bit, so a splat of -0.0
+/// cannot be lowered to it.
+static bool isSplatZero(Type elemType, DenseElementsAttr val) {
+  if (!val || !val.isSplat())
+    return false;
+  if (llvm::isa<FloatType>(elemType))
+    return val.getSplatValue<APFloat>().isPosZero();
+  if (llvm::isa<IntegerType>(elemType))
+    return val.getSplatValue<APInt>().isZero();
+  return false;
+}
+
+namespace {
+/// Lower 'vector.transfer_write' op to 'arm_sme.intr.zero' op. Currently only
+/// supports 2d scalable vector type 'vector<[16]x[16]xi8>' that maps to the
+/// ZA0.B SME virtual tile. This will be extended to support more element types.
+///
+/// Only unmasked, in-bounds writes of a constant splat of zero at indices
+/// [0, 0] of a rank-2 memref with a unit innermost stride are matched. ZA is
+/// zeroed and then each of its SVLb vectors is stored to the memref row with
+/// the same index using 'arm_sme.intr.str'.
+struct TransferWriteToArmSMEZeroLowering
+    : public ConvertOpToLLVMPattern<vector::TransferWriteOp> {
+  using ConvertOpToLLVMPattern<vector::TransferWriteOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferWriteOp write, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto vType = write.getVectorType();
+    if (vType.getRank() != 2)
+      return failure();
+    if (vType.getShape() != ArrayRef<int64_t>({kMinNumElts, kMinNumElts}))
+      return failure();
+    if (vType.getElementType() != rewriter.getI8Type())
+      return failure();
+    // Both dimensions must be scalable. 'getScalableDims()' has one entry per
+    // dimension, so checking only its size would also accept fixed-length
+    // vectors such as 'vector<16x16xi8>'.
+    if (llvm::is_contained(vType.getScalableDims(), false))
+      return failure();
+
+    // ZA vectors are stored as contiguous bytes, one per memref row, so the
+    // destination must be a rank-2 memref of the vector's element type whose
+    // innermost dimension has unit stride.
+    auto memRefType = llvm::dyn_cast<MemRefType>(write.getSource().getType());
+    if (!memRefType || memRefType.getRank() != 2 ||
+        memRefType.getElementType() != vType.getElementType() ||
+        !isLastMemrefDimUnitStride(memRefType))
+      return failure();
+
+    // The whole of ZA is written starting at the memref origin, so the write
+    // must be unmasked, in-bounds and at indices [0, 0].
+    if (write.getMask() || write.hasOutOfBoundsDim())
+      return failure();
+    auto isZeroIndex = [](Value index) {
+      auto cst = index.getDefiningOp<arith::ConstantIndexOp>();
+      return cst && cst.value() == 0;
+    };
+    if (!llvm::all_of(write.getIndices(), isZeroIndex))
+      return failure();
+
+    auto constant = write.getVector().getDefiningOp<arith::ConstantOp>();
+    if (!constant)
+      return failure();
+
+    auto denseAttr = dyn_cast<DenseElementsAttr>(constant.getValueAttr());
+    if (!denseAttr || !isSplatZero(vType.getElementType(), denseAttr))
+      return failure();
+
+    auto loc = write.getLoc();
+
+    // Create 'arm_sme.intr.zero' intrinsic to zero ZA.
+    auto tile = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(kZeroZAMask));
+    rewriter.create<arm_sme::aarch64_sme_zero>(loc, tile);
+
+    // Create loop that iterates from 0 to SVLB-1 inclusive (the number of
+    // vectors in ZA) and stores each ZA vector to memory.
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto minElems = rewriter.create<arith::ConstantIndexOp>(loc, kMinNumElts);
+    auto vscale =
+        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto upperBound = rewriter.create<arith::MulIOp>(loc, minElems, vscale);
+    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+    rewriter.setInsertionPointToStart(forOp.getBody());
+
+    // Create 'arm_sme.intr.str' intrinsic to store ZA vector.
+    auto vnumI64 = rewriter.create<arith::IndexCastUIOp>(
+        loc, rewriter.getI64Type(), forOp.getInductionVar());
+    auto offset =
+        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0);
+    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getSource(),
+                                     ValueRange{vnumI64, offset}, rewriter);
+    auto vnumI32 = rewriter.create<arith::IndexCastUIOp>(
+        loc, rewriter.getI32Type(), forOp.getInductionVar());
+    rewriter.create<arm_sme::aarch64_sme_str>(loc, vnumI32, ptr);
+
+    rewriter.eraseOp(write);
+
+    return success();
+  }
+};
+} // namespace
+
+/// Collect the patterns that lower vector ops to ArmSME intrinsics.
+void mlir::arm_sme::populateVectorTransferLoweringPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+  patterns.add<TransferWriteToArmSMEZeroLowering>(converter);
+}

diff  --git a/mlir/test/Dialect/ArmSME/vector-ops.mlir b/mlir/test/Dialect/ArmSME/vector-ops.mlir
new file mode 100644
index 00000000000000..19b9896bc42a29
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/vector-ops.mlir
@@ -0,0 +1,104 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" -split-input-file | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: @transfer_write_2d_zero_i8
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
+// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[C255:.*]] = arith.constant 255 : i32
+// CHECK-NEXT: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> ()
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 16 : index
+// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index
+// CHECK-NEXT: scf.for %[[VNUM:.*]] = %[[C0_0]] to %[[NUM_ZA_VECTORS]] step %[[C1]] {
+// CHECK-NEXT:   %[[VNUM_I64:.*]] = arith.index_castui %[[VNUM]] : index to i64
+// CHECK-NEXT:   %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT:   %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT:   %[[OFF0:.*]] = llvm.mul %[[VNUM_I64]], %[[STRIDE0]]  : i64
+// CHECK-NEXT:   %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_1]]  : i64
+// CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT:   %[[VNUM_I32:.*]] = arith.index_castui %[[VNUM]] : index to i32
+// CHECK-NEXT:   "arm_sme.intr.str"(%[[VNUM_I32]], %[[GEP]]) : (i32, !llvm.ptr) -> ()
+func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0> : vector<[16]x[16]xi8>
+  vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+  return
+}
+
+// -----
+
+// The following tests check that the 'vector.transfer_write' ->
+// 'arm_sme.intr.zero' lowering does not occur for an unsupported element type,
+// shape or rank, a non-memref destination, or a value other than a zero splat.
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_type
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_type(%arg0 : memref<?x?xi4>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0> : vector<[16]x[16]xi4>
+  vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi4>, memref<?x?xi4>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_shape
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_shape(%arg0 : memref<?x?xi8>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0> : vector<[8]x[8]xi8>
+  vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xi8>, memref<?x?xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_rank
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_rank(%arg0 : memref<?x?x?xi8>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0> : vector<[16]x[16]x[16]xi8>
+  vector.transfer_write %cst, %arg0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<[16]x[16]x[16]xi8>, memref<?x?x?xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__non_memref_type
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> tensor<?x?xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0> : vector<[16]x[16]xi8>
+  %0 = vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, tensor<?x?xi8>
+  return %0 : tensor<?x?xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__non_zero_value
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__non_zero_value(%arg0 : memref<?x?xi8>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<1> : vector<[16]x[16]xi8>
+  vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__vec_unknown_defining_op
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__vec_unknown_defining_op(%arg0 : memref<?x?xi8>, %arg1 : vector<[16]x[16]xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %arg1, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+  return
+}

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
new file mode 100644
index 00000000000000..70b53cfa8cf855
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
@@ -0,0 +1,142 @@
+// RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \
+// RUN:   -convert-vector-to-llvm="enable-arm-sme" -test-lower-to-llvm | \
+// RUN: mlir-translate -mlir-to-llvmir | \
+// RUN: %lli_aarch64_cmd --march=aarch64 --mattr="+sve,+sme" \
+// RUN:   --entry-function=entry \
+// RUN:   --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func.func @entry() -> i32 {
+  %c0 = arith.constant 0 : index
+  %c1_i8 = arith.constant 1 : i8
+  %c1_index = arith.constant 1 : index
+
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+
+  // "svl" refers to the Streaming Vector Length and "svl_b" the number of
+  // 8-bit elements in a vector of SVL bits.
+  %svl_b = arith.muli %c16, %vscale : index
+
+  // Allocate memory and fill with ones.
+  //
+  // TODO: type conversion of rank > 1 vector types generates array(s) of
+  // vectors. This is invalid for scalable vectors since LLVM doesn't support
+  // arrays of scalable vectors. This prevents initializing 2-d vectors with
+  // 'vector.store' or 'vector.transfer_write' ops until this is resolved or
+  // there's a custom lowering path.
+  %za_b = memref.alloca(%svl_b, %svl_b) : memref<?x?xi8>
+  scf.for %i = %c0 to %svl_b step %c1_index {
+    scf.for %j = %c0 to %svl_b step %c1_index {
+      memref.store %c1_i8, %za_b[%i, %j] : memref<?x?xi8>
+    }
+  }
+
+  // Verify memory is ones by doing a mul reduction with initial value of one.
+  %init_1 = arith.constant 1 : i64
+  %mul_reduce = scf.for %vnum = %c0 to %svl_b step %c1_index iter_args(%iter = %init_1) -> (i64) {
+    %row = vector.load %za_b[%vnum, %c0] : memref<?x?xi8>, vector<[16]xi8>
+
+    %inner_mul_reduce = scf.for %offset = %c0 to %svl_b step %c1_index iter_args(%inner_iter = %init_1) -> (i64) {
+      %t = vector.extractelement %row[%offset : index] : vector<[16]xi8>
+      %t_i64 = arith.extui %t : i8 to i64
+      %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64
+      scf.yield %inner_mul_reduce_next : i64
+    }
+
+    %mul_reduce_next = arith.muli %iter, %inner_mul_reduce : i64
+    scf.yield %mul_reduce_next : i64
+  }
+
+  // CHECK: 1
+  vector.print %mul_reduce : i64
+
+  // Verify the mul reduction works as expected.
+  //
+  // TODO: ZA currently isn't re-enabled after calls and is therefore disabled
+  // by the callee on return. Once this is resolved this can be moved to a
+  // function.
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : i8
+  %c7 = arith.constant 7 : index
+  %c15 = arith.constant 15 : i8
+  memref.store %c4, %za_b[%c3, %c7] : memref<?x?xi8>
+  memref.store %c15, %za_b[%c7, %c3] : memref<?x?xi8>
+  %mul_reduce2 = scf.for %vnum = %c0 to %svl_b step %c1_index iter_args(%iter = %init_1) -> (i64) {
+    %row = vector.load %za_b[%vnum, %c0] : memref<?x?xi8>, vector<[16]xi8>
+
+    %inner_mul_reduce = scf.for %offset = %c0 to %svl_b step %c1_index iter_args(%inner_iter = %init_1) -> (i64) {
+      %t = vector.extractelement %row[%offset : index] : vector<[16]xi8>
+      %t_i64 = arith.extui %t : i8 to i64
+      %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64
+      scf.yield %inner_mul_reduce_next : i64
+    }
+
+    %mul_reduce_next = arith.muli %iter, %inner_mul_reduce : i64
+    scf.yield %mul_reduce_next : i64
+  }
+
+  // 15*4=60
+  // CHECK: 60
+  vector.print %mul_reduce2 : i64
+
+  // Fill memory with zeroes.
+  //
+  // This will get lowered to:
+  //
+  //   zero {za}
+  //   for vnum = 0; vnum < SVLb; ++vnum;
+  //     str za[vnum], [ptr]
+  //     ...
+  //
+  %cst_0 = arith.constant dense<0> : vector<[16]x[16]xi8>
+  vector.transfer_write %cst_0, %za_b[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+
+  // Verify memory is zeroed by doing an add reduction with initial value of
+  // zero.
+  %init_0 = arith.constant 0 : i64
+  %add_reduce = scf.for %vnum = %c0 to %svl_b step %c1_index iter_args(%iter = %init_0) -> (i64) {
+    %row = vector.load %za_b[%vnum, %c0] : memref<?x?xi8>, vector<[16]xi8>
+
+    %inner_add_reduce = scf.for %offset = %c0 to %svl_b step %c1_index iter_args(%inner_iter = %init_0) -> (i64) {
+      %t = vector.extractelement %row[%offset : index] : vector<[16]xi8>
+      %t_i64 = arith.extui %t : i8 to i64
+      %inner_add_reduce_next = arith.addi %inner_iter, %t_i64 : i64
+      scf.yield %inner_add_reduce_next : i64
+    }
+
+    %add_reduce_next = arith.addi %iter, %inner_add_reduce : i64
+    scf.yield %add_reduce_next : i64
+  }
+
+  // CHECK-NEXT: 0
+  vector.print %add_reduce : i64
+
+  // Verify the add reduction works as expected.
+  //
+  // TODO: ZA currently isn't re-enabled after calls and is therefore disabled
+  // by the callee on return. Once this is resolved this can be moved to a
+  // function.
+  memref.store %c4, %za_b[%c3, %c7] : memref<?x?xi8>
+  memref.store %c15, %za_b[%c7, %c3] : memref<?x?xi8>
+  %add_reduce2 = scf.for %vnum = %c0 to %svl_b step %c1_index iter_args(%iter = %init_0) -> (i64) {
+    %row = vector.load %za_b[%vnum, %c0] : memref<?x?xi8>, vector<[16]xi8>
+
+    %inner_add_reduce = scf.for %offset = %c0 to %svl_b step %c1_index iter_args(%inner_iter = %init_0) -> (i64) {
+      %t = vector.extractelement %row[%offset : index] : vector<[16]xi8>
+      %t_i64 = arith.extui %t : i8 to i64
+      %inner_add_reduce_next = arith.addi %inner_iter, %t_i64 : i64
+      scf.yield %inner_add_reduce_next : i64
+    }
+
+    %add_reduce_next = arith.addi %iter, %inner_add_reduce : i64
+    scf.yield %add_reduce_next : i64
+  }
+
+  // 15+4=19
+  // CHECK-NEXT: 19
+  vector.print %add_reduce2 : i64
+
+  %c0_i32 = arith.constant 0 : i32
+  return %c0_i32 : i32
+}

diff  --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index 453a887c6d7fde..7beec1f61aa923 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -221,6 +221,8 @@ llvm.func @arm_sme_store(%nxv1i1  : vector<[1]xi1>,
   // CHECK: call void @llvm.aarch64.sme.st1b.vert
   "arm_sme.intr.st1b.vert"(%nxv16i1, %p8, %c0, %c0) :
               (vector<[16]xi1>, !llvm.ptr<i8>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.str
+  "arm_sme.intr.str"(%c0, %p8) : (i32, !llvm.ptr<i8>) -> ()
   llvm.return
 }
 


        


More information about the Mlir-commits mailing list