[Mlir-commits] [mlir] [mlir][ArmSME] Add support for lowering masked tile_load ops (PR #70915)
Cullen Rhodes
llvmlistbot at llvm.org
Wed Nov 1 02:18:58 PDT 2023
https://github.com/c-rhodes updated https://github.com/llvm/llvm-project/pull/70915
>From b01f92fa817fbe3e2633e1e693ac8c66a7097339 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 08:47:40 +0000
Subject: [PATCH 1/3] [mlir][ArmSME] Propagate pad and mask in
vector.transfer_read lowering
This extends the vector.transfer_read -> arm_sme.tile_load lowering to
propagate the pad and mask.
The restriction that the transfer_read must be a transposition is also
removed; identity permutation maps are now lowered to regular horizontal
tile loads.
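For example (a minimal sketch mirroring the updated tests; %src, %pad and
%mask are illustrative names), a masked 2-D read such as:

  %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask
    {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>

now lowers directly to a masked tile load:

  %0 = arm_sme.tile_load %src[%c0, %c0], %pad, %mask
    : memref<?x?xi32>, vector<[4]x[4]xi32>

When the mask is absent the padding is dropped, since it is only meaningful
for out-of-bounds accesses (unsupported here) or masking.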
---
.../VectorToArmSME/VectorToArmSME.cpp | 57 ++++---
.../Dialect/ArmSME/vector-ops-to-sme.mlir | 140 +++++++++---------
2 files changed, 109 insertions(+), 88 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 005dd546bf1632b..5491f7dd30629ad 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -60,15 +60,30 @@ getSMETileAndCastToVector(PatternRewriter &rewriter, Location loc,
namespace {
-/// Conversion pattern for vector.transfer_read op with transpose permutation
-/// map to vertical arm_sme.tile_load (in-flight transpose).
+/// Conversion pattern for vector.transfer_read.
+///
+/// ---
+///
+/// Example 1: op with identity permutation map to horizontal
+/// arm_sme.tile_load:
+///
+/// vector.transfer_read ... permutation_map: (d0, d1) -> (d0, d1)
+///
+/// is converted to:
+///
+/// arm_sme.tile_load ...
+///
+/// ---
+///
+/// Example 2: op with transpose permutation map to vertical arm_sme.tile_load
+/// (in-flight transpose):
///
/// vector.transfer_read ... permutation_map: (d0, d1) -> (d1, d0)
///
/// is converted to:
///
/// arm_sme.tile_load ... layout<vertical>
-struct TransferReadPermutationToArmSMELowering
+struct TransferReadToArmSMELowering
: public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
@@ -79,15 +94,6 @@ struct TransferReadPermutationToArmSMELowering
return rewriter.notifyMatchFailure(transferReadOp,
"not a 2 result permutation map");
- AffineMap map = transferReadOp.getPermutationMap();
-
- // Permutation map doesn't perform permutation, can be lowered to
- // vector.load by TransferReadToVectorLoadLowering and then
- // arm_sme.tile_load by VectorLoadToArmSMELowering.
- if (map.isIdentity())
- return rewriter.notifyMatchFailure(
- transferReadOp, "map is an identity, apply another pattern");
-
auto vectorType = transferReadOp.getVectorType();
if (!arm_sme::isValidSMETileVectorType(vectorType))
return rewriter.notifyMatchFailure(transferReadOp,
@@ -96,26 +102,33 @@ struct TransferReadPermutationToArmSMELowering
if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType()))
return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
- if (transferReadOp.getMask())
- // TODO: support masking.
- return rewriter.notifyMatchFailure(transferReadOp,
- "masking not yet supported");
-
// Out-of-bounds dims are not supported.
if (transferReadOp.hasOutOfBoundsDim())
return rewriter.notifyMatchFailure(transferReadOp,
"not inbounds transfer read");
+ arm_sme::TileSliceLayout layout;
+
AffineExpr d0, d1;
bindDims(transferReadOp.getContext(), d0, d1);
- if (map != AffineMap::get(map.getNumDims(), 0, {d1, d0},
- transferReadOp.getContext()))
+ AffineMap map = transferReadOp.getPermutationMap();
+ if (map.isIdentity())
+ layout = arm_sme::TileSliceLayout::Horizontal;
+ else if (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
+ transferReadOp.getContext()))
+ layout = arm_sme::TileSliceLayout::Vertical;
+ else
return rewriter.notifyMatchFailure(transferReadOp,
- "not true 2-D matrix transpose");
+ "unsupported permutation map");
+ // Padding isn't optional for transfer_read, but is only used in the case
+ // of out-of-bounds accesses (not supported here) and/or masking. Mask is
+ // optional; if it's not present, don't pass padding.
+ auto mask = transferReadOp.getMask();
+ auto padding = mask ? transferReadOp.getPadding() : nullptr;
rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
transferReadOp, vectorType, transferReadOp.getSource(),
- transferReadOp.getIndices(), arm_sme::TileSliceLayout::Vertical);
+ transferReadOp.getIndices(), padding, mask, layout);
return success();
}
@@ -531,7 +544,7 @@ struct VectorOuterProductToArmSMELowering
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
patterns.add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
- SplatOpToArmSMELowering, TransferReadPermutationToArmSMELowering,
+ SplatOpToArmSMELowering, TransferReadToArmSMELowering,
TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
VectorOuterProductToArmSMELowering>(&ctx);
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 5f41313fc6ac789..ed33f8508dba0bf 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -1,181 +1,189 @@
// RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
//===----------------------------------------------------------------------===//
-// vector.transfer_read (with in-flight transpose)
+// vector.transfer_read
//===----------------------------------------------------------------------===//
-// CHECK-LABEL: @transfer_read_2d_transpose_i8
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
-func.func @transfer_read_2d_transpose_i8(%src : memref<?x?xi8>) {
+// CHECK-LABEL: @transfer_read_2d_i8
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @transfer_read_2d_i8(%src : memref<?x?xi8>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0 : i8
- %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8>
"prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
return
}
// -----
-// CHECK-LABEL: @transfer_read_2d_transpose_i16
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
-func.func @transfer_read_2d_transpose_i16(%src : memref<?x?xi16>) {
+// CHECK-LABEL: @transfer_read_2d_i16
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @transfer_read_2d_i16(%src : memref<?x?xi16>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0 : i16
- %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
+ %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
"prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
return
}
// -----
-// CHECK-LABEL: @transfer_read_2d_transpose_i32
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
-func.func @transfer_read_2d_transpose_i32(%src : memref<?x?xi32>) {
+// CHECK-LABEL: @transfer_read_2d_i32
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @transfer_read_2d_i32(%src : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0 : i32
- %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+ %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
"prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
return
}
// -----
-// CHECK-LABEL: @transfer_read_2d_transpose_i64
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
-func.func @transfer_read_2d_transpose_i64(%src : memref<?x?xi64>) {
+// CHECK-LABEL: @transfer_read_2d_i64
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @transfer_read_2d_i64(%src : memref<?x?xi64>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0 : i64
- %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi64>, vector<[2]x[2]xi64>
+ %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi64>, vector<[2]x[2]xi64>
"prevent.dce"(%0) : (vector<[2]x[2]xi64>) -> ()
return
}
// -----
-// CHECK-LABEL: @transfer_read_2d_transpose_i128
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
-func.func @transfer_read_2d_transpose_i128(%src : memref<?x?xi128>) {
+// CHECK-LABEL: @transfer_read_2d_i128
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @transfer_read_2d_i128(%src : memref<?x?xi128>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0 : i128
- %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi128>, vector<[1]x[1]xi128>
+ %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi128>, vector<[1]x[1]xi128>
"prevent.dce"(%0) : (vector<[1]x[1]xi128>) -> ()
return
}
// -----
-// CHECK-LABEL: @transfer_read_2d_transpose_f16
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
-func.func @transfer_read_2d_transpose_f16(%src : memref<?x?xf16>) {
+// CHECK-LABEL: @transfer_read_2d_f16
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @transfer_read_2d_f16(%src : memref<?x?xf16>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f16
- %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf16>, vector<[8]x[8]xf16>
+ %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf16>, vector<[8]x[8]xf16>
"prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
return
}
// -----
-// CHECK-LABEL: @transfer_read_2d_transpose_bf16
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
-func.func @transfer_read_2d_transpose_bf16(%src : memref<?x?xbf16>) {
+// CHECK-LABEL: @transfer_read_2d_bf16
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @transfer_read_2d_bf16(%src : memref<?x?xbf16>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : bf16
- %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
"prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
return
}
// -----
-// CHECK-LABEL: @transfer_read_2d_transpose_f32
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
-func.func @transfer_read_2d_transpose_f32(%src : memref<?x?xf32>) {
+// CHECK-LABEL: @transfer_read_2d_f32
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @transfer_read_2d_f32(%src : memref<?x?xf32>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
- %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
return
}
// -----
-// CHECK-LABEL: @transfer_read_2d_transpose_f64
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
-func.func @transfer_read_2d_transpose_f64(%src : memref<?x?xf64>) {
+// CHECK-LABEL: @transfer_read_2d_f64
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @transfer_read_2d_f64(%src : memref<?x?xf64>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f64
- %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
+ %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
return
}
// -----
-// CHECK-LABEL: @transfer_read_2d__bad_type
-// CHECK-NOT: arm_sme.tile_load
-// CHECK: vector.transfer_read
-func.func @transfer_read_2d__bad_type(%src : memref<?x?xf64>) {
+// CHECK-LABEL: @transfer_read_2d_with_mask_i16
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}], {{.*}}, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @transfer_read_2d_with_mask_i16(%src : memref<?x?xi16>, %mask : vector<[8]x[8]xi1>) {
%c0 = arith.constant 0 : index
- %pad = arith.constant 0.0 : f64
- %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[4]x[4]xf64>
- "prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> ()
+ %pad = arith.constant 0 : i16
+ %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
+ "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
return
}
// -----
-// CHECK-LABEL: @transfer_read_2d__non_memref_type
-// CHECK-NOT: arm_sme.tile_load
-// CHECK: vector.transfer_read
-func.func @transfer_read_2d__non_memref_type(%src : tensor<?x?xf64>) {
+/// in-flight transpose
+
+// CHECK-LABEL: @transfer_read_2d_transpose_i8
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @transfer_read_2d_transpose_i8(%src : memref<?x?xi8>) {
%c0 = arith.constant 0 : index
- %pad = arith.constant 0.0 : f64
- %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor<?x?xf64>, vector<[2]x[2]xf64>
- "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+ %pad = arith.constant 0 : i8
+ %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8>
+ "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
return
}
// -----
-// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank
+// CHECK-LABEL: @transfer_read_2d_transpose_with_mask_f32
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @transfer_read_2d_transpose_with_mask_f32(%src : memref<?x?xf32>, %mask : vector<[4]x[4]xi1>) {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+ "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d__bad_type
// CHECK-NOT: arm_sme.tile_load
// CHECK: vector.transfer_read
-func.func @transfer_read_2d__bad_transfer_rank(%src : memref<?x?xf64>) {
+func.func @transfer_read_2d__bad_type(%src : memref<?x?xf64>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f64
- %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref<?x?xf64>, vector<[2]xf64>
- "prevent.dce"(%0) : (vector<[2]xf64>) -> ()
+ %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[4]x[4]xf64>
+ "prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> ()
return
}
// -----
-// CHECK-LABEL: @transfer_read_2d__unsupported_mask
+// CHECK-LABEL: @transfer_read_2d__non_memref_type
// CHECK-NOT: arm_sme.tile_load
// CHECK: vector.transfer_read
-func.func @transfer_read_2d__unsupported_mask(%src : memref<?x?xf64>, %mask : vector<[2]x[2]xi1>) {
+func.func @transfer_read_2d__non_memref_type(%src : tensor<?x?xf64>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f64
- %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
+ %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor<?x?xf64>, vector<[2]x[2]xf64>
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
return
}
// -----
-/// transfer_read with identity map should be lowered to vector.load by
-/// TransferReadToVectorLoadLowering and then arm_sme.tile_load by
-/// VectorLoadToArmSMELowering.
-
-// CHECK-LABEL: @transfer_read_2d__non_permuting_map
+// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank
// CHECK-NOT: arm_sme.tile_load
// CHECK: vector.transfer_read
-func.func @transfer_read_2d__non_permuting_map(%src : memref<?x?xf64>) {
+func.func @transfer_read_2d__bad_transfer_rank(%src : memref<?x?xf64>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f64
- %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0, d1)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
- "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+ %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref<?x?xf64>, vector<[2]xf64>
+ "prevent.dce"(%0) : (vector<[2]xf64>) -> ()
return
}
>From 61283b7837488d0e0234bfc1c6bf9baa53cba183 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 09:15:52 +0000
Subject: [PATCH 2/3] [mlir][ArmSME] Add support for lowering masked tile_load
ops
This patch extends ArmSMEToSCF to support lowering of masked tile_load
ops. Only masks created by 'vector.create_mask' are currently supported.
There are two lowerings: one for a pad of constant zero and another for a
non-zero pad. For the following example:
  %pad = arith.constant 0 : i32
  %num_rows = arith.constant 2 : index
  %num_cols = arith.constant 4 : index
  %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>,
    vector<[4]x[4]xi32>
The former (pad of constant zero) is lowered as follows:
---------------------------------------------------------
  %tile = arm_sme.zero : vector<[4]x[4]xi32>
  %num_cols = vector.create_mask %c4 : vector<[4]xi1>
  scf.for %slice_idx = %c0 to %num_rows step %c1 {
    %tile_update = arm_sme.load_tile_slice
      %src[%slice_idx], %num_cols, %tile, %slice_idx :
      memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
  }
The tile is zeroed to satisfy the padding and only the active rows are
loaded.
The latter (non-zero pad) is lowered as follows:
------------------------------------------------
  scf.for %slice_idx = %c0 to %num_tile_slices step %c1 {
    %row_is_active = arith.cmpi ult %slice_idx, %num_rows : index
    %slice = scf.if %row_is_active -> vector<[4]xi32> {
      %slice = vector.maskedload %src[%slice_idx, %c0], %num_cols, %pad_1d :
        memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
      scf.yield %slice : vector<[4]xi32>
    } else {
      scf.yield %pad_1d : vector<[4]xi32>
    }
    arm_sme.move_vector_to_tile_slice %slice, %tile, %slice_idx
      : vector<[4]xi32> into vector<[4]x[4]xi32>
  }
The scalar pad is broadcast to a 1-D vector and a regular
'vector.maskedload' (which will be lowered to SVE, not SME) loads each slice
for the active rows, with the padding passed as the passthru. For inactive
rows the slice is just the 1-D pad. The resulting slice is inserted into the
tile with 'arm_sme.move_vector_to_tile_slice'.
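For reference, these lowerings can be exercised with a pass pipeline along
the lines of the one used by the integration test added in this patch (a
sketch only; "input.mlir" is a placeholder and the flags are taken from the
test's RUN configuration):

  mlir-opt input.mlir \
    -enable-arm-streaming="mode=locally enable-za" \
    -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
    -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
    -allocate-arm-sme-tiles -test-lower-to-llvm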
---
.../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 254 +++++++++++++++++-
.../ArmSMEToSCF/arm-sme-to-scf.mlir | 56 ++++
.../CPU/ArmSME/test-transfer-read-2d.mlir | 212 +++++++++++++++
3 files changed, 519 insertions(+), 3 deletions(-)
create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 50cc818f1ffc090..491c7604433ddf1 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -80,9 +80,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
PatternRewriter &rewriter) const override {
if (tileLoadOp.getMask())
- // TODO: add masked patterns.
return rewriter.notifyMatchFailure(
- tileLoadOp, "op has mask, needs masked pattern(s)");
+ tileLoadOp, "op has mask, apply masked patterns");
OpBuilder::InsertionGuard g(rewriter);
auto loc = tileLoadOp.getLoc();
@@ -142,6 +141,254 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
}
};
+/// Lower `arm_sme.tile_load` with mask and pad of constant zero.
+///
+/// BEFORE:
+/// ```mlir
+/// %pad = arith.constant 0 : i32
+/// %num_rows = arith.constant 2 : index
+/// %num_cols = arith.constant 4 : index
+/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
+/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
+/// memref<?x?xi32>, vector<[4]x[4]xi32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %c0 = arith.constant 0 : index
+/// %c1 = arith.constant 1 : index
+/// %tile = arm_sme.zero : vector<[4]x[4]xi32>
+/// %num_cols = vector.create_mask %c4 : vector<[4]xi1>
+/// scf.for %tile_slice_idx = %c0 to %num_rows step %c1 {
+/// %tile_update = arm_sme.load_tile_slice
+/// %src[%tile_slice_idx], %num_cols, %tile, %tile_slice_idx :
+/// memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
+/// }
+/// ```
+///
+/// NOTE: Only masks created by 'vector.create_mask' are currently supported.
+struct TileLoadOpWithMaskAndPadZeroConversion
+ : public OpRewritePattern<arm_sme::TileLoadOp> {
+ using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
+ PatternRewriter &rewriter) const override {
+ OpBuilder::InsertionGuard g(rewriter);
+ auto loc = tileLoadOp.getLoc();
+ auto tileType = tileLoadOp.getVectorType();
+
+ auto maskOp = tileLoadOp.getMask();
+ if (!maskOp)
+ return rewriter.notifyMatchFailure(
+ tileLoadOp, "op has no mask, needs unmasked pattern");
+
+ auto padOp = tileLoadOp.getPadding();
+ assert(padOp && "expected padding when masking!");
+
+ auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return rewriter.notifyMatchFailure(
+ tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
+ "currently supported");
+
+ auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
+ if (!constPadOp || constPadOp.getValue() !=
+ rewriter.getZeroAttr(tileType.getElementType()))
+ return rewriter.notifyMatchFailure(
+ tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
+
+ auto numRows = createMaskOp.getOperands()[0];
+ auto numCols = createMaskOp.getOperands()[1];
+
+ auto predicateType =
+ VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+ auto numColsOp =
+ rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
+
+ // Initialize tile with zero to satisfy padding. Inactive cols will be
+ // zeroed anyway since the loads use zeroing predication. For inactive rows
+ // however, no load will occur so these need to be zeroed.
+ auto tile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
+
+ // Create a loop to load the active tile slices from memory.
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto upperBound = numRows;
+ auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+
+ rewriter.setInsertionPointToStart(forOp.getBody());
+
+ // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
+ // tile.
+ SmallVector<Value> memrefIndices;
+ auto tileSliceIndex = forOp.getInductionVar();
+ getMemrefIndices(tileLoadOp.getIndices(),
+ tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
+ upperBound, memrefIndices, loc, rewriter);
+ rewriter.create<arm_sme::LoadTileSliceOp>(
+ loc, tileType, tileLoadOp.getBase(), numColsOp, tile, memrefIndices,
+ tileSliceIndex, tileLoadOp.getLayout());
+
+ rewriter.setInsertionPointAfter(forOp);
+
+ // Replace 'arm_sme.tile_load' with the tile.
+ rewriter.replaceOp(tileLoadOp, tile);
+
+ return success();
+ }
+};
+
+/// Lower `arm_sme.tile_load` with mask and non-zero pad.
+///
+/// BEFORE:
+/// ```mlir
+/// %pad = arith.constant 1 : i32
+/// %num_rows = arith.constant 2 : index
+/// %num_cols = arith.constant 4 : index
+/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
+/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
+/// memref<?x?xi32>, vector<[4]x[4]xi32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %pad_1d = arith.constant dense<1> : vector<[4]xi32>
+/// %num_rows = arith.constant 2 : index
+/// %num_cols = arith.constant 4 : index
+/// %tile_id = arm_sme.get_tile_id : i32
+/// %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
+/// %vscale = vector.vscale
+/// %c0 = arith.constant 0 : index
+/// %c1 = arith.constant 1 : index
+/// %min_svl_s = arith.constant 4 : index
+/// %svl_s = arith.muli %min_svl_s, %vscale : index
+/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
+/// %row_is_active = arith.cmpi ult %tile_slice_idx, %num_rows : index
+/// %slice = scf.if %row_is_active -> vector<[4]xi32> {
+/// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %num_cols, %pad
+/// : memref<?x?xi32>, vector<[4]xi1>,
+/// vector<[4]xi32> into vector<[4]xi32>
+/// scf.yield %slice : vector<[4]xi32>
+/// } else {
+/// scf.yield %pad_1d : vector<[4]xi32>
+/// }
+/// // Insert slice into tile
+/// arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
+/// : vector<[4]xi32> into vector<[4]x[4]xi32>
+/// }
+/// ```
+struct TileLoadOpWithMaskAndPadNonZeroConversion
+ : public OpRewritePattern<arm_sme::TileLoadOp> {
+ using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
+ PatternRewriter &rewriter) const override {
+ OpBuilder::InsertionGuard g(rewriter);
+ auto loc = tileLoadOp.getLoc();
+ auto tileType = tileLoadOp.getVectorType();
+ auto tileElementType = tileType.getElementType();
+ unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+
+ auto maskOp = tileLoadOp.getMask();
+ if (!maskOp)
+ return rewriter.notifyMatchFailure(
+ tileLoadOp, "op has no mask, needs unmasked pattern");
+
+ auto padOp = tileLoadOp.getPadding();
+ assert(padOp && "expected padding when masking!");
+
+ auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return rewriter.notifyMatchFailure(
+ tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
+ "currently supported");
+
+ auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
+ if (constPadOp &&
+ constPadOp.getValue() == rewriter.getZeroAttr(tileElementType))
+ return rewriter.notifyMatchFailure(
+ tileLoadOp, "op has constant zero pad, needs zero pad pattern");
+
+ auto numRows = createMaskOp.getOperands()[0];
+ auto numCols = createMaskOp.getOperands()[1];
+
+ VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
+ auto predicateType =
+ VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+ auto numColsOp =
+ rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
+
+ // Create 'arm_sme.get_tile' op.
+ auto tileId = rewriter.create<arm_sme::GetTileID>(
+ loc, rewriter.getIntegerType(tileElementWidth));
+
+ // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to
+ // use as input tile to 'arm_sme.load_tile_slice' ops.
+ auto tile =
+ rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);
+
+ // Create a loop that loads each ZA tile slice from memory.
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+ loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
+ auto vscale =
+ rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+ auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto numTileSlices =
+ rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+ auto forOp =
+ rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+
+ rewriter.setInsertionPointToStart(forOp.getBody());
+
+ auto tileSliceIndex = forOp.getInductionVar();
+
+ auto rowIsActive = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
+
+ SmallVector<Value> memrefIndices;
+ getMemrefIndices(tileLoadOp.getIndices(),
+ tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
+ numTileSlices, memrefIndices, loc, rewriter);
+
+ // Splat pad into 1-D vector matching type of tile slice.
+ auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
+
+ Operation *slice = rewriter.create<scf::IfOp>(
+ loc, rowIsActive,
+ [&](OpBuilder &b, Location loc) {
+ // If the row is active, emit a masked load where the predicate is
+ // 'numCols'. Pad is used for inactive elements, taken from
+ // passthru.
+ auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
+ loc, tileSliceType, tileLoadOp.getBase(), memrefIndices,
+ numColsOp, /*passthru=*/pad1DOp);
+ rewriter.create<scf::YieldOp>(loc, loadSlice->getResult(0));
+ },
+ [&](OpBuilder &b, Location loc) {
+ // Inactive rows are filled with pad.
+ rewriter.create<scf::YieldOp>(loc, pad1DOp.getResult());
+ });
+
+ // TODO: If the load is vertical the transpose can't be done in-flight with
+ // a regular (SVE) maskedload. Propagate layout to
+ // 'arm_sme.move_vector_to_tile_slice' below once it supports layout. This
+ // is currently broken.
+
+ // Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
+ rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, slice->getResult(0), tile, tileSliceIndex,
+ tileLoadOp.getLayout());
+
+ rewriter.setInsertionPointAfter(forOp);
+
+ // Replace 'arm_sme.tile_load' with the tile.
+ rewriter.replaceOp(tileLoadOp, tile);
+
+ return success();
+ }
+};
+
/// Lower `arm_sme.tile_store` to a loop over the tile slices and store each
/// slice using `arm_sme.store_tile_slice`.
///
@@ -266,7 +513,8 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
} // namespace
void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
- patterns.add<TileLoadOpConversion, TileStoreOpConversion,
+ patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
+ TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion,
TileVectorPrintOpConversion>(patterns.getContext());
}
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 3fb320c0d219e60..4906812032ae903 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -33,6 +33,62 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
return
}
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(
+// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>) {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
+// CHECK-DAG: %[[TILEZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
+// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] {
+// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[TILEZERO]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>) {
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %pad = arith.constant 0 : i32
+ %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
+ %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(
+// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>,
+// CHECK-SAME: %[[PAD:.*]]: i32) {
+// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
+// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+// CHECK-NEXT: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK-NEXT: %[[ROW_IS_ACTIVE:.*]] = arith.cmpi ult, %[[TILE_SLICE_INDEX]], %[[NUM_ROWS]] : index
+// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK: %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
+// CHECK: %[[SLICE:.*]] = scf.if %[[ROW_IS_ACTIVE]] -> (vector<[4]xi32>) {
+// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
+// CHECK: scf.yield %[[LOAD_SLICE]] : vector<[4]xi32>
+// CHECK: } else {
+// CHECK: scf.yield %[[PAD_1D]] : vector<[4]xi32>
+// CHECK: }
+// CHECK: arm_sme.move_vector_to_tile_slice %[[SLICE]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
+func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
+ %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
//===----------------------------------------------------------------------===//
// arm_sme.tile_store
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
new file mode 100644
index 000000000000000..644f90d950645b8
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
@@ -0,0 +1,212 @@
+// DEFINE: %{entry_point} = entry
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE: -march=aarch64 -mattr=+sve,+sme \
+// DEFINE: -e %{entry_point} -entry-point-result=void \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+// Vector load.
+func.func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %c4 = arith.constant 4 : index
+ %pad = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %A[%base1, %base2], %pad {in_bounds=[true, true]} :
+ memref<?x?xf32>, vector<[4]x[4]xf32>
+
+ vector.print str "TILE BEGIN:"
+ vector.print %0: vector<[4]x[4]xf32>
+
+ return
+}
+
+// Vector load + transpose.
+func.func @transfer_read_2d_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %pad = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %A[%base1, %base2], %pad
+ {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]}
+ : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+ vector.print str "TILE BEGIN:"
+ vector.print %0 : vector<[4]x[4]xf32>
+
+ return
+}
+
+// Vector load with mask and pad of zero.
+func.func @transfer_read_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %pad = arith.constant 0.0 : f32
+ %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+ %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+ {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+ vector.print str "TILE BEGIN:"
+ vector.print %0: vector<[4]x[4]xf32>
+
+ return
+}
+
+// Vector load with mask and pad of zero + transpose.
+func.func @transfer_read_2d_mask_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %pad = arith.constant 0.0 : f32
+ %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+ %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+ {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]}
+ : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+ vector.print str "TILE BEGIN:"
+ vector.print %0: vector<[4]x[4]xf32>
+
+ return
+}
+
+// Vector load with mask and non-zero pad.
+func.func @transfer_read_2d_mask_non_zero_pad(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %pad = arith.constant -42.0 : f32
+ %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+ %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+ {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+ vector.print str "TILE BEGIN:"
+ vector.print %0: vector<[4]x[4]xf32>
+
+ return
+}
+
+// Vector load with mask and non-zero pad + transpose.
+func.func @transfer_read_2d_mask_non_zero_pad_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %pad = arith.constant -42.0 : f32
+ %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+ %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+ {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]}
+ : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+ vector.print str "TILE BEGIN:"
+ vector.print %0: vector<[4]x[4]xf32>
+
+ return
+}
+
+// Allocate heap memory of size 'd0' x 'd1' and initialize.
+//
+// Example:
+//
+// initialize_memory(%c4, %c5)
+//
+// 0, 1, 2, 3, 4
+// 10, 11, 12, 13, 14
+// 20, 21, 22, 23, 24
+// 30, 31, 32, 33, 34
+//
+// Returns a dynamic memref. It's the caller's responsibility to free the
+// returned memref.
+func.func @initialize_memory(%d0 : index, %d1 : index) -> memref<?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c1_f32 = arith.constant 1.0 : f32
+ %c10_f32 = arith.constant 10.0 : f32
+
+ %A = memref.alloc(%d0, %d1) : memref<?x?xf32>
+
+ %init = arith.constant 0.0 : f32
+ scf.for %i = %c0 to %d0 step %c1 iter_args(%val = %init) -> f32 {
+ scf.for %j = %c0 to %d1 step %c1 iter_args(%inner_val = %val) -> f32 {
+ memref.store %inner_val, %A[%i, %j] : memref<?x?xf32>
+ %inner_val_next = arith.addf %inner_val, %c1_f32 : f32
+ scf.yield %inner_val_next : f32
+ }
+ %val_next = arith.addf %val, %c10_f32 : f32
+ scf.yield %val_next : f32
+ }
+
+ return %A : memref<?x?xf32>
+}
+
+func.func @entry() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+
+ // Allocate enough memory to load a 32-bit tile plus a tiny bit more to test
+ // non-zero offsets while remaining inbounds.
+ %vscale = vector.vscale
+ %svl_s = arith.muli %c4, %vscale : index
+ %svl_s_plus_two = arith.addi %svl_s, %c2 : index
+
+ %A = call @initialize_memory(%svl_s_plus_two, %svl_s_plus_two) : (index, index) -> memref<?x?xf32>
+
+ // 1.a. Read 2D vector from 2D memref.
+ //
+ // CHECK-LABEL: TILE BEGIN:
+ // CHECK-NEXT: ( 0, 1, 2, 3
+ // CHECK-NEXT: ( 10, 11, 12, 13
+ // CHECK-NEXT: ( 20, 21, 22, 23
+ // CHECK-NEXT: ( 30, 31, 32, 33
+ call @transfer_read_2d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+ // 1.b. Same as 1.a., but with non-zero offsets.
+ //
+ // CHECK-LABEL: TILE BEGIN:
+ // CHECK-NEXT: ( 12, 13, 14, 15
+ // CHECK-NEXT: ( 22, 23, 24, 25
+ // CHECK-NEXT: ( 32, 33, 34, 35
+ // CHECK-NEXT: ( 42, 43, 44, 45
+ call @transfer_read_2d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
+
+ // 2. Same as 1.a., but with mask and a pad of constant zero.
+ // CHECK-LABEL: TILE BEGIN:
+ // CHECK-NEXT: ( 0, 1, 2, 0
+ // CHECK-NEXT: ( 10, 11, 12, 0
+ // CHECK-NEXT: ( 0, 0, 0, 0
+ // CHECK-NEXT: ( 0, 0, 0, 0
+ call @transfer_read_2d_mask(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+ // 3. Same as 1.a., but with mask and non-zero pad.
+ // CHECK-LABEL: TILE BEGIN:
+ // CHECK-NEXT: ( 0, 1, 2, -42
+ // CHECK-NEXT: ( 10, 11, 12, -42
+ // CHECK-NEXT: ( -42, -42, -42, -42
+ // CHECK-NEXT: ( -42, -42, -42, -42
+ call @transfer_read_2d_mask_non_zero_pad(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+ // 4. Same as 1.a., but transpose the result.
+ // CHECK-LABEL: TILE BEGIN:
+ // CHECK-NEXT: ( 0, 10, 20, 30
+ // CHECK-NEXT: ( 1, 11, 21, 31
+ // CHECK-NEXT: ( 2, 12, 22, 32
+ // CHECK-NEXT: ( 3, 13, 23, 33
+ call @transfer_read_2d_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+ // 5. Same as 2., but transpose the result.
+ // CHECK-LABEL: TILE BEGIN:
+ // CHECK-NEXT: ( 0, 10, 0, 0
+ // CHECK-NEXT: ( 1, 11, 0, 0
+ // CHECK-NEXT: ( 2, 12, 0, 0
+ // CHECK-NEXT: ( 0, 0, 0, 0
+ call @transfer_read_2d_mask_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+ // 6. Same as 3., but transpose the result.
+ // CHECK-LABEL: TILE BEGIN:
+ // CHECK-NEXT: ( 0, 10, -42, -42
+ // CHECK-NEXT: ( 1, 11, -42, -42
+ // CHECK-NEXT: ( 2, 12, -42, -42
+ // CHECK-NEXT: ( -42, -42, -42, -42
+ call @transfer_read_2d_mask_non_zero_pad_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+ memref.dealloc %A : memref<?x?xf32>
+
+ return
+}
>From d3602364195412e0963680ea5ad275f0423e4ab2 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 1 Nov 2023 09:18:42 +0000
Subject: [PATCH 3/3] run clang-format
---
mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 491c7604433ddf1..4e5021a65a2395b 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -80,8 +80,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
PatternRewriter &rewriter) const override {
if (tileLoadOp.getMask())
- return rewriter.notifyMatchFailure(
- tileLoadOp, "op has mask, apply masked patterns");
+ return rewriter.notifyMatchFailure(tileLoadOp,
+ "op has mask, apply masked patterns");
OpBuilder::InsertionGuard g(rewriter);
auto loc = tileLoadOp.getLoc();