[Mlir-commits] [mlir] [mlir][ArmSME] Lower transfer_write + transpose to vertical store (PR #71181)
Cullen Rhodes
llvmlistbot at llvm.org
Wed Nov 8 04:59:18 PST 2023
https://github.com/c-rhodes updated https://github.com/llvm/llvm-project/pull/71181
>From bee71e7618a7cc4246b86d09f4f281695ef874dc Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Mon, 16 Oct 2023 11:43:32 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Lower transfer_write + transpose to
vertical store
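This patch extends the lowering of vector.transfer_write in
VectorToArmSME to support an in-flight transpose via an SME vertical
store: a transfer_write whose permutation_map is the 2-D transpose
(d0, d1) -> (d1, d0) is lowered to arm_sme.tile_store with
layout<vertical>, while an identity map keeps the horizontal layout.
Transfers with an out-of-bounds dim, or with any other permutation map,
are now rejected by the pattern.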
---
.../VectorToArmSME/VectorToArmSME.cpp | 47 +++++++++++++++++--
.../Dialect/ArmSME/vector-ops-to-sme.mlir | 42 +++++++++++++++++
.../CPU/ArmSME/test-transfer-write-2d.mlir | 38 +++++++++++++++
3 files changed, 124 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 5491f7dd30629ad..a8956f0d38fba9d 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -136,13 +136,31 @@ struct TransferReadToArmSMELowering
/// Conversion pattern for vector.transfer_write.
///
-/// vector.transfer_write %vector, %source[%c0, %c0] : vector<[16]x[16]xi8>,
-/// memref<?x?xi8>
+/// ---
+///
+/// Example 1: op with identity permutation map to horizontal
+/// arm_sme.tile_store:
+///
+/// vector.transfer_write %vector, %source[%c0, %c0]
+/// {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
///
/// is converted to:
///
/// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
/// vector<[16]x[16]xi8>
+/// ---
+///
+/// Example 2: op with transpose permutation map to vertical arm_sme.tile_store
+/// (in-flight transpose):
+///
+/// vector.transfer_write %vector, %source[%c0, %c0]
+/// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
+/// in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+///
+/// is converted to:
+///
+/// arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
+/// : memref<?x?xi8>, vector<[16]x[16]xi8>
struct TransferWriteToArmSMELowering
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
@@ -153,12 +171,35 @@ struct TransferWriteToArmSMELowering
if (!arm_sme::isValidSMETileVectorType(vType))
return failure();
+ assert(writeOp.getTransferRank() == 2 &&
+ "expected a permutation_map with result dims of the same rank as "
+ "the vector type");
+
if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
return failure();
+ // Out-of-bounds dims are not supported.
+ if (writeOp.hasOutOfBoundsDim())
+ return rewriter.notifyMatchFailure(writeOp,
+ "not inbounds transfer write");
+
+ arm_sme::TileSliceLayout layout;
+
+ AffineExpr d0, d1;
+ bindDims(writeOp.getContext(), d0, d1);
+ AffineMap map = writeOp.getPermutationMap();
+ if (map.isIdentity())
+ layout = arm_sme::TileSliceLayout::Horizontal;
+ else if (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
+ writeOp.getContext()))
+ layout = arm_sme::TileSliceLayout::Vertical;
+ else
+ return rewriter.notifyMatchFailure(writeOp,
+ "unsupported permutation map");
+
rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
- writeOp.getMask());
+ writeOp.getMask(), layout);
return success();
}
};
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index ed33f8508dba0bf..a1ad25ed77aa8ef 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -337,6 +337,37 @@ func.func @transfer_write_2d_with_mask_f64(%vector : vector<[2]x[2]xf64>, %dest
// -----
+/// A transpose permutation map (d0, d1) -> (d1, d0) lowers to an in-flight transpose: arm_sme.tile_store with layout<vertical>.
+
+// CHECK-LABEL: func.func @transfer_write_2d_transpose_i64(
+// CHECK-SAME: %[[VECTOR:.*]]: vector<[2]x[2]xi64>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi64>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @transfer_write_2d_transpose_i64(%vector : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vector, %dest[%c0, %c0] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[2]x[2]xi64>, memref<?x?xi64>
+ return
+}
+
+// -----
+
+/// Masked variant: the transpose lowers to a vertical arm_sme.tile_store and the mask operand is passed through to it.
+
+// CHECK-LABEL: func.func @transfer_write_2d_transpose_with_mask_bf16(
+// CHECK-SAME: %[[VECTOR:.*]]: vector<[8]x[8]xbf16>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xbf16>,
+// CHECK-SAME: %[[MASK:.*]]: vector<[8]x[8]xi1>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]], %[[MASK]] layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>, %mask : vector<[8]x[8]xi1>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vector, %dest[%c0, %c0], %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[8]x[8]xbf16>, memref<?x?xbf16>
+ return
+}
+
+// -----
+
// The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
// lowering only occurs for vector types of correct rank, shape, element size
// and number of scalable dims.
@@ -398,6 +429,17 @@ func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?
return
}
+// -----
+// A transfer_write that is not known to be in-bounds (no in_bounds attribute) must not be lowered to arm_sme.tile_store.
+// CHECK-LABEL: @transfer_write_2d__out_of_bounds
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.tile_store
+func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vector, %dest[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+ return
+}
+
//===----------------------------------------------------------------------===//
// vector.broadcast
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
index b599b976c3e1592..174cec857437a7d 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
@@ -32,6 +32,25 @@ func.func @transfer_write_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: i
return
}
+// Vector store + transpose.
+func.func @transfer_write_2d_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ vector.transfer_write %0, %A[%base1, %base2] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} :
+ vector<[4]x[4]xf32>, memref<?x?xf32> // Writes the transpose of %0 back to %A at the same base indices.
+ return
+}
+
+// Masked vector store + transpose.
+func.func @transfer_write_2d_mask_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+ %mask = vector.create_mask %c4, %c2 : vector<[4]x[4]xi1> // Enables the leading 4 x 2 region; all other lanes are off.
+ %0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ vector.transfer_write %0, %A[%base1, %base2], %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} :
+ vector<[4]x[4]xf32>, memref<?x?xf32>
+ return
+}
+
// Vector load + print.
func.func @load_and_print(%A : memref<?x?xf32>, %base1: index, %base2: index) {
%0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
@@ -116,6 +135,25 @@ func.func @entry() {
call @transfer_write_2d_mask(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+ // 4. Reload 3. + store + transpose.
+ // CHECK-LABEL: TILE BEGIN:
+ // CHECK-NEXT: ( 0, 0, 20, 30
+ // CHECK-NEXT: ( 0, 0, 21, 31
+ // CHECK-NEXT: ( 0, 0, 0, 0
+ // CHECK-NEXT: ( 3, 13, 0, 0
+ call @transfer_write_2d_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+ call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+ // 5. Reload 4. + store + transpose but with mask (nrows=4, ncols=2).
+ // The mask applies after permutation
+ // CHECK-LABEL: TILE BEGIN:
+ // CHECK-NEXT: ( 0, 0, 20, 30
+ // CHECK-NEXT: ( 0, 0, 21, 31
+ // CHECK-NEXT: ( 20, 21, 0, 0
+ // CHECK-NEXT: ( 30, 31, 0, 0
+ call @transfer_write_2d_mask_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+ call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
memref.dealloc %A : memref<?x?xf32>
return
>From 81e183b15e6f55f1e22dee87af0b37a36002cb39 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 8 Nov 2023 12:56:25 +0000
Subject: [PATCH 2/2] Address review comments (simplify layout selection, drop assert, clarify tests)
---
.../VectorToArmSME/VectorToArmSME.cpp | 20 ++++++++-----------
.../Dialect/ArmSME/vector-ops-to-sme.mlir | 2 +-
.../CPU/ArmSME/test-transfer-write-2d.mlir | 11 +++++-----
3 files changed, 15 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index a8956f0d38fba9d..953a465c18de69f 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -171,10 +171,6 @@ struct TransferWriteToArmSMELowering
if (!arm_sme::isValidSMETileVectorType(vType))
return failure();
- assert(writeOp.getTransferRank() == 2 &&
- "expected a permutation_map with result dims of the same rank as "
- "the vector type");
-
if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
return failure();
@@ -183,20 +179,20 @@ struct TransferWriteToArmSMELowering
return rewriter.notifyMatchFailure(writeOp,
"not inbounds transfer write");
- arm_sme::TileSliceLayout layout;
-
AffineExpr d0, d1;
bindDims(writeOp.getContext(), d0, d1);
AffineMap map = writeOp.getPermutationMap();
- if (map.isIdentity())
- layout = arm_sme::TileSliceLayout::Horizontal;
- else if (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
- writeOp.getContext()))
- layout = arm_sme::TileSliceLayout::Vertical;
- else
+ bool isTranspose = (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
+ writeOp.getContext()));
+
+ if (!map.isIdentity() && !isTranspose)
return rewriter.notifyMatchFailure(writeOp,
"unsupported permutation map");
+ arm_sme::TileSliceLayout layout =
+ isTranspose ? arm_sme::TileSliceLayout::Vertical
+ : arm_sme::TileSliceLayout::Horizontal;
+
rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
writeOp.getMask(), layout);
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index a1ad25ed77aa8ef..ae3b260da83a2c1 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -436,7 +436,7 @@ func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?
// CHECK-NOT: arm_sme.tile_store
func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
%c0 = arith.constant 0 : index
- vector.transfer_write %vector, %dest[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+ vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [false, false]} : vector<[4]x[4]xf32>, memref<?x?xf32>
return
}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
index 174cec857437a7d..49c513badb7b071 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
@@ -32,7 +32,7 @@ func.func @transfer_write_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: i
return
}
-// Vector store + transpose.
+// Vector transpose + store.
func.func @transfer_write_2d_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
%0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
vector.transfer_write %0, %A[%base1, %base2] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} :
@@ -40,7 +40,7 @@ func.func @transfer_write_2d_transposed(%A : memref<?x?xf32>, %base1: index, %ba
return
}
-// Masked vector store + transpose.
+// Vector transpose + masked store.
func.func @transfer_write_2d_mask_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
@@ -135,7 +135,7 @@ func.func @entry() {
call @transfer_write_2d_mask(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
- // 4. Reload 3. + store + transpose.
+ // 4. Reload 3. + transpose + store.
// CHECK-LABEL: TILE BEGIN:
// CHECK-NEXT: ( 0, 0, 20, 30
// CHECK-NEXT: ( 0, 0, 21, 31
@@ -144,8 +144,9 @@ func.func @entry() {
call @transfer_write_2d_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
- // 5. Reload 4. + store + transpose but with mask (nrows=4, ncols=2).
- // The mask applies after permutation
+ // 5. Reload 4. + transpose + masked store (nrows=4, ncols=2).
+ // The mask applies after permutation. Columns 2 and 3 (from 4.) are
+ // preserved.
// CHECK-LABEL: TILE BEGIN:
// CHECK-NEXT: ( 0, 0, 20, 30
// CHECK-NEXT: ( 0, 0, 21, 31
More information about the Mlir-commits
mailing list