[Mlir-commits] [mlir] [mlir][vector] Propagate scalability to gather/scatter ptrs vector (PR #97584)
Cullen Rhodes
llvmlistbot at llvm.org
Wed Jul 3 07:38:45 PDT 2024
https://github.com/c-rhodes created https://github.com/llvm/llvm-project/pull/97584
In convert-vector-to-llvm, the first operand to the masked.gather (and
masked.scatter) intrinsic, i.e. the vector of pointers holding the memory
addresses to access, is always built with a fixed vector type.
This can produce intrinsics where the scalable flag has been dropped:
```
%0 = llvm.intr.masked.gather %1, %2, %3 {alignment = 4 : i32}
: (!llvm.vec<4 x ptr>, vector<[4]xi1>, vector<[4]xi32>) -> vector<[4]xi32>
```
Fortunately, the pointer operand is overloaded on the result type, so we
still end up with the correct IR when lowering to LLVM, but the MLIR is
nonetheless incorrect. This patch fixes it by propagating scalability to
the vector of pointers.
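With the fix, the scalable flag is preserved on the pointer vector, so the
example above instead lowers as follows (extrapolated from the updated
tests in the second patch below, which use 3-element vectors):
```
%0 = llvm.intr.masked.gather %1, %2, %3 {alignment = 4 : i32}
  : (!llvm.vec<? x 4 x ptr>, vector<[4]xi1>, vector<[4]xi32>) -> vector<[4]xi32>
```
For reference, the final LLVM IR was already correct because the
masked.gather intrinsic name is mangled on both the result and pointer
vector types. A hand-written sketch of that IR (not output from this
patch; the value names are illustrative):
```
%res = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0(
         <vscale x 4 x ptr> %ptrs, i32 4, <vscale x 4 x i1> %mask,
         <vscale x 4 x i32> %passthru)
```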
From f6977f427933937b3298119c5a24ffab7e448135 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 3 Jul 2024 14:20:36 +0000
Subject: [PATCH 1/2] [mlir][vector] add conversion tests for scalable
gather/scatter
---
.../VectorToLLVM/vector-to-llvm.mlir | 25 +++++++++++++++++++
1 file changed, 25 insertions(+)
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 09b79708a9ab2..7de0285baad89 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2248,6 +2248,19 @@ func.func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3
// -----
+func.func @gather_op_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %arg2: vector<[3]xi1>, %arg3: vector<[3]xf32>) -> vector<[3]xf32> {
+ %0 = arith.constant 0: index
+ %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<[3]xi32>, vector<[3]xi1>, vector<[3]xf32> into vector<[3]xf32>
+ return %1 : vector<[3]xf32>
+}
+
+// CHECK-LABEL: func @gather_op_scalable
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<3 x ptr>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK: return %[[G]] : vector<[3]xf32>
+
+// -----
+
func.func @gather_op_global_memory(%arg0: memref<?xf32, 1>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
%0 = arith.constant 0: index
%1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32, 1>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
@@ -2351,6 +2364,18 @@ func.func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<
// -----
+func.func @scatter_op_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %arg2: vector<[3]xi1>, %arg3: vector<[3]xf32>) {
+ %0 = arith.constant 0: index
+ vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<[3]xi32>, vector<[3]xi1>, vector<[3]xf32>
+ return
+}
+
+// CHECK-LABEL: func @scatter_op_scalable
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<3 x ptr>, f32
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into !llvm.vec<3 x ptr>
+
+// -----
+
func.func @scatter_op_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: vector<3xi1>, %arg3: vector<3xindex>) {
%0 = arith.constant 0: index
vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref<?xindex>, vector<3xindex>, vector<3xi1>, vector<3xindex>
From 704b9e1a7ecf568bb03601abd5fbfde810805cd2 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 3 Jul 2024 14:23:14 +0000
Subject: [PATCH 2/2] [mlir][vector] Propagate scalability to gather/scatter
ptrs vector
In convert-vector-to-llvm, the first operand to the masked.gather (and
masked.scatter) intrinsic, i.e. the vector of pointers holding the memory
addresses to access, is always built with a fixed vector type.
This can produce intrinsics where the scalable flag has been dropped:
  %0 = llvm.intr.masked.gather %1, %2, %3 {alignment = 4 : i32}
    : (!llvm.vec<4 x ptr>, vector<[4]xi1>, vector<[4]xi32>) -> vector<[4]xi32>
Fortunately, the pointer operand is overloaded on the result type, so we
still end up with the correct IR when lowering to LLVM, but the MLIR is
nonetheless incorrect. This patch fixes it by propagating scalability to
the vector of pointers.
---
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 21 ++++++++++---------
.../VectorToLLVM/vector-to-llvm.mlir | 8 +++----
2 files changed, 15 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 0eac55255b133..77bdacbc46990 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -102,11 +102,13 @@ static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
const LLVMTypeConverter &typeConverter,
MemRefType memRefType, Value llvmMemref, Value base,
- Value index, uint64_t vLen) {
+ Value index, VectorType vectorType) {
assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
"unsupported memref type");
auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
- auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
+ auto ptrsType =
+ LLVM::getVectorType(pType, vectorType.getDimSize(0),
+ /*isScalable=*/vectorType.getScalableDims()[0]);
return rewriter.create<LLVM::GEPOp>(
loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
base, index);
@@ -288,9 +290,9 @@ class VectorGatherOpConversion
if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
auto vType = gather.getVectorType();
// Resolve address.
- Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
- memRefType, base, ptr, adaptor.getIndexVec(),
- /*vLen=*/vType.getDimSize(0));
+ Value ptrs =
+ getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
+ base, ptr, adaptor.getIndexVec(), vType);
// Replace with the gather intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
@@ -305,8 +307,7 @@ class VectorGatherOpConversion
// Resolve address.
Value ptrs = getIndexedPtrs(
rewriter, loc, typeConverter, memRefType, base, ptr,
- /*index=*/vectorOperands[0],
- LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
+ /*index=*/vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
// Create the gather intrinsic.
return rewriter.create<LLVM::masked_gather>(
loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
@@ -343,9 +344,9 @@ class VectorScatterOpConversion
VectorType vType = scatter.getVectorType();
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
- Value ptrs = getIndexedPtrs(
- rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
- ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
+ Value ptrs =
+ getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
+ adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
// Replace with the scatter intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 7de0285baad89..6f8145b618b71 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2255,8 +2255,8 @@ func.func @gather_op_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %arg
}
// CHECK-LABEL: func @gather_op_scalable
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
// CHECK: return %[[G]] : vector<[3]xf32>
// -----
@@ -2371,8 +2371,8 @@ func.func @scatter_op_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %ar
}
// CHECK-LABEL: func @scatter_op_scalable
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into !llvm.vec<3 x ptr>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into !llvm.vec<? x 3 x ptr>
// -----