[Mlir-commits] [mlir] [MLIR][Vector] Allow non-default memory spaces in gather/scatter lowerings (PR #67500)

Quinn Dawkins llvmlistbot at llvm.org
Thu Sep 28 13:30:13 PDT 2023


https://github.com/qedawkins updated https://github.com/llvm/llvm-project/pull/67500

From d9264e6e229494b0e03ae958fa6ceb7d946e4703 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Tue, 26 Sep 2023 19:07:11 -0400
Subject: [PATCH] [Vector] Allow non-default memory spaces in gather/scatter
 lowerings

GPU targets can gather from non-default address spaces (e.g. global memory),
so this removes the requirement that the memref use the default (zero) memory
space; the address space only needs to be convertible by the type converter.
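
As an illustration (not part of the patch itself; the function name and the
choice of address space 3 are hypothetical), a gather from a non-default
address space such as the following is now accepted by the lowering:

  // Hypothetical example: address space 3 chosen only for illustration.
  func.func @gather_space_3(%base: memref<?xf32, 3>, %indices: vector<3xi32>,
                            %mask: vector<3xi1>, %pass: vector<3xf32>) -> vector<3xf32> {
    %c0 = arith.constant 0 : index
    %0 = vector.gather %base[%c0][%indices], %mask, %pass
        : memref<?xf32, 3>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
    return %0 : vector<3xf32>
  }

Run through -convert-vector-to-llvm, this should produce an llvm.getelementptr
on !llvm.ptr<3> feeding llvm.intr.masked.gather, mirroring the address-space-1
test added below.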
---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp           |  6 ++----
 .../Conversion/VectorToLLVM/vector-to-llvm.mlir    | 14 ++++++++++++++
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 3f77c5b5f24e9b5..8427d60f14c0bcc 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -87,14 +87,12 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
   return success();
 }
 
-// Check if the last stride is non-unit or the memory space is not zero.
+// Check that the last stride is unit and the memory space is supported.
 static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
                                            const LLVMTypeConverter &converter) {
   if (!isLastMemrefDimUnitStride(memRefType))
     return failure();
-  FailureOr<unsigned> addressSpace =
-      converter.getMemRefAddressSpace(memRefType);
-  if (failed(addressSpace) || *addressSpace != 0)
+  if (failed(converter.getMemRefAddressSpace(memRefType)))
     return failure();
   return success();
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 9f7581211fd828d..9aa4d735681f576 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2108,6 +2108,20 @@ func.func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3
 
 // -----
 
+func.func @gather_op_global_memory(%arg0: memref<?xf32, 1>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
+  %0 = arith.constant 0: index
+  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32, 1>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+  return %1 : vector<3xf32>
+}
+
+// CHECK-LABEL: func @gather_op_global_memory
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<3xi32>) -> !llvm.vec<3 x ptr<1>>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<1>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: return %[[G]] : vector<3xf32>
+
+// -----
+
+
 func.func @gather_op_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: vector<3xi1>, %arg3: vector<3xindex>) -> vector<3xindex> {
   %0 = arith.constant 0: index
   %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xindex>, vector<3xindex>, vector<3xi1>, vector<3xindex> into vector<3xindex>


