[Mlir-commits] [mlir] [Vector] Allow non-default memory spaces in gather/scatter lowerings (PR #67500)
Quinn Dawkins
llvmlistbot at llvm.org
Tue Sep 26 16:19:32 PDT 2023
https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/67500
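GPU targets can gather from and scatter to non-default address spaces (e.g. global memory), so this patch removes the check that restricted these lowerings to the default memory space. The memref's memory space must still be convertible to an LLVM address space.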
From d81195593e41e79fc37d1b23286d2741700310e4 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Tue, 26 Sep 2023 19:07:11 -0400
Subject: [PATCH] [Vector] Allow non-default memory spaces in gather/scatter
lowerings
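GPU targets can gather from and scatter to non-default address spaces
(e.g. global memory), so this removes the check that restricted these
lowerings to the default memory space. The memory space must still be
convertible to an LLVM address space.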
---
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 6 ++----
.../Conversion/VectorToLLVM/vector-to-llvm.mlir | 14 ++++++++++++++
2 files changed, 16 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 3f77c5b5f24e9b5..8427d60f14c0bcc 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -87,14 +87,12 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
return success();
}
-// Check if the last stride is non-unit or the memory space is not zero.
+// Check if the memref has a unit last stride and a valid address space.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
const LLVMTypeConverter &converter) {
if (!isLastMemrefDimUnitStride(memRefType))
return failure();
- FailureOr<unsigned> addressSpace =
- converter.getMemRefAddressSpace(memRefType);
- if (failed(addressSpace) || *addressSpace != 0)
+ if (failed(converter.getMemRefAddressSpace(memRefType)))
return failure();
return success();
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 41ab06f2e23b501..0d075bca2e3c45c 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2108,6 +2108,20 @@ func.func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3
// -----
+func.func @gather_op_global_memory(%arg0: memref<?xf32, 1>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
+ %0 = arith.constant 0: index
+ %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32, 1>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+ return %1 : vector<3xf32>
+}
+
+// CHECK-LABEL: func @gather_op_global_memory
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<3xi32>) -> !llvm.vec<3 x ptr<1>>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<1>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: return %[[G]] : vector<3xf32>
+
+// -----
+
+
func.func @gather_op_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: vector<3xi1>, %arg3: vector<3xindex>) -> vector<3xindex> {
%0 = arith.constant 0: index
%1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xindex>, vector<3xindex>, vector<3xi1>, vector<3xindex> into vector<3xindex>
More information about the Mlir-commits
mailing list