[Mlir-commits] [mlir] 78c4974 - [MLIR][Vector] Allow non-default memory spaces in gather/scatter lowerings (#67500)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 28 16:20:36 PDT 2023

Author: Quinn Dawkins
Date: 2023-09-28T19:20:32-04:00
New Revision: 78c49743c7fb2353b1ca56d56f9ffda7f47e4bae

URL: https://github.com/llvm/llvm-project/commit/78c49743c7fb2353b1ca56d56f9ffda7f47e4bae
DIFF: https://github.com/llvm/llvm-project/commit/78c49743c7fb2353b1ca56d56f9ffda7f47e4bae.diff

LOG: [MLIR][Vector] Allow non-default memory spaces in gather/scatter lowerings (#67500)

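GPU targets can gather on non-default address spaces (e.g. global), so
this removes the requirement that the memref live in the default memory
space (address space 0). The lowering still fails if the memref's memory
space cannot be converted to an address space.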




diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 3f77c5b5f24e9b5..8427d60f14c0bcc 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -87,14 +87,12 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
   return success();
-// Check if the last stride is non-unit or the memory space is not zero.
+// Check if the memref type has a unit last stride and a valid memory space.
 static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
                                            const LLVMTypeConverter &converter) {
   if (!isLastMemrefDimUnitStride(memRefType))
     return failure();
-  FailureOr<unsigned> addressSpace =
-      converter.getMemRefAddressSpace(memRefType);
-  if (failed(addressSpace) || *addressSpace != 0)
+  if (failed(converter.getMemRefAddressSpace(memRefType)))
     return failure();
   return success();

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 9f7581211fd828d..9aa4d735681f576 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2108,6 +2108,20 @@ func.func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3
 // -----
+func.func @gather_op_global_memory(%arg0: memref<?xf32, 1>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
+  %0 = arith.constant 0: index
+  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32, 1>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+  return %1 : vector<3xf32>
+// CHECK-LABEL: func @gather_op
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<3xi32>) -> !llvm.vec<3 x ptr<1>>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<1>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: return %[[G]] : vector<3xf32>
+// -----
 func.func @gather_op_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: vector<3xi1>, %arg3: vector<3xindex>) -> vector<3xindex> {
   %0 = arith.constant 0: index
   %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xindex>, vector<3xindex>, vector<3xi1>, vector<3xindex> into vector<3xindex>


More information about the Mlir-commits mailing list