[Mlir-commits] [mlir] [MLIR][Vector] Allow non-default memory spaces in gather/scatter lowerings (PR #67500)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 26 16:20:37 PDT 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir

<details>
<summary>Changes</summary>

GPU targets can gather from non-default address spaces (e.g., global memory), so this removes the check that restricted these lowerings to the default (0) memory space.
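
For illustration, here is a minimal sketch (function and value names are hypothetical) of the kind of IR this enables, mirroring the test added below: a `vector.gather` whose base memref lives in a non-default address space now converts through the VectorToLLVM lowering instead of failing to legalize.

```mlir
// Sketch: gather from address space 1 (e.g. GPU global memory); with this
// change it lowers to llvm.intr.masked.gather on a vector of !llvm.ptr<1>.
func.func @gather_from_global(%base: memref<16xf32, 1>, %indices: vector<4xi32>,
                              %mask: vector<4xi1>, %pass: vector<4xf32>) -> vector<4xf32> {
  %c0 = arith.constant 0 : index
  %0 = vector.gather %base[%c0][%indices], %mask, %pass
      : memref<16xf32, 1>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
  return %0 : vector<4xf32>
}
```

Scatter takes the same path, since gather and scatter share the `isMemRefTypeSupported` helper, which now only requires that the address space be resolvable by the type converter.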

---
Full diff: https://github.com/llvm/llvm-project/pull/67500.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+2-4) 
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+14) 


``````````diff
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 3f77c5b5f24e9b5..8427d60f14c0bcc 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -87,14 +87,12 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
   return success();
 }
 
-// Check if the last stride is non-unit or the memory space is not zero.
+// Check that the memref has a unit last stride and a memory space the type
+// converter can resolve.
 static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
                                            const LLVMTypeConverter &converter) {
   if (!isLastMemrefDimUnitStride(memRefType))
     return failure();
-  FailureOr<unsigned> addressSpace =
-      converter.getMemRefAddressSpace(memRefType);
-  if (failed(addressSpace) || *addressSpace != 0)
+  if (failed(converter.getMemRefAddressSpace(memRefType)))
     return failure();
   return success();
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 41ab06f2e23b501..0d075bca2e3c45c 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2108,6 +2108,20 @@ func.func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3
 
 // -----
 
+func.func @gather_op_global_memory(%arg0: memref<?xf32, 1>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
+  %0 = arith.constant 0: index
+  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32, 1>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+  return %1 : vector<3xf32>
+}
+
+// CHECK-LABEL: func @gather_op_global_memory
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<3xi32>) -> !llvm.vec<3 x ptr<1>>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<1>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: return %[[G]] : vector<3xf32>
+
+// -----
+
+
 func.func @gather_op_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: vector<3xi1>, %arg3: vector<3xindex>) -> vector<3xindex> {
   %0 = arith.constant 0: index
   %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xindex>, vector<3xindex>, vector<3xi1>, vector<3xindex> into vector<3xindex>

``````````

</details>


https://github.com/llvm/llvm-project/pull/67500

