[Mlir-commits] [mlir] 1e7d6d3 - [mlir][vector] Propagate scalability to gather/scatter ptrs vector (#97584)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 9 01:06:28 PDT 2024


Author: Cullen Rhodes
Date: 2024-07-09T09:06:25+01:00
New Revision: 1e7d6d345518de0738eb5b2eb71addd499fea0fb

URL: https://github.com/llvm/llvm-project/commit/1e7d6d345518de0738eb5b2eb71addd499fea0fb
DIFF: https://github.com/llvm/llvm-project/commit/1e7d6d345518de0738eb5b2eb71addd499fea0fb.diff

LOG: [mlir][vector] Propagate scalability to gather/scatter ptrs vector (#97584)

In convert-vector-to-llvm the vector-of-pointers operand (holding all
memory addresses to read or write) of the masked.gather and
masked.scatter intrinsics is always created with a fixed vector type.

This may result in intrinsics where the scalable flag has been dropped:
```
  %0 = llvm.intr.masked.gather %1, %2, %3 {alignment = 4 : i32}
    : (!llvm.vec<4 x ptr>, vector<[4]xi1>, vector<[4]xi32>) -> vector<[4]xi32>
```
Fortunately, the operand is overloaded on the result type, so we still end
up with the correct LLVM IR after translation, but the MLIR itself is
incorrect. This patch fixes it by propagating scalability to the pointer
vector type and adds verifiers that reject a mismatch.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Target/LLVMIR/llvmir-invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 57af89f5dbf8d..2e1635e590cad 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -900,6 +900,8 @@ def LLVM_masked_gather : LLVM_OneResultIntrOp<"masked.gather"> {
       $_resultType, $ptrs, $mask, $pass_thru, $_int_attr($alignment));
   }];
   list<int> llvmArgIndices = [0, 2, 3, 1];
+
+  let hasVerifier = 1;
 }
 
 /// Create a call to Masked Scatter intrinsic.
@@ -919,6 +921,8 @@ def LLVM_masked_scatter : LLVM_ZeroResultIntrOp<"masked.scatter"> {
       $value, $ptrs, $mask, $_int_attr($alignment));
   }];
   list<int> llvmArgIndices = [0, 1, 3, 2];
+
+  let hasVerifier = 1;
 }
 
 /// Create a call to Masked Expand Load intrinsic.

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 6a8a9d818aad2..f6b1c42dcd24c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -102,11 +102,14 @@ static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
                             const LLVMTypeConverter &typeConverter,
                             MemRefType memRefType, Value llvmMemref, Value base,
-                            Value index, uint64_t vLen) {
+                            Value index, VectorType vectorType) {
   assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
          "unsupported memref type");
+  assert(vectorType.getRank() == 1 && "expected a 1-d vector type");
   auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
-  auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
+  auto ptrsType =
+      LLVM::getVectorType(pType, vectorType.getDimSize(0),
+                          /*isScalable=*/vectorType.getScalableDims()[0]);
   return rewriter.create<LLVM::GEPOp>(
       loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
       base, index);
@@ -288,9 +291,9 @@ class VectorGatherOpConversion
     if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
       auto vType = gather.getVectorType();
       // Resolve address.
-      Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
-                                  memRefType, base, ptr, adaptor.getIndexVec(),
-                                  /*vLen=*/vType.getDimSize(0));
+      Value ptrs =
+          getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
+                         base, ptr, adaptor.getIndexVec(), vType);
       // Replace with the gather intrinsic.
       rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
           gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
@@ -305,8 +308,7 @@ class VectorGatherOpConversion
       // Resolve address.
       Value ptrs = getIndexedPtrs(
           rewriter, loc, typeConverter, memRefType, base, ptr,
-          /*index=*/vectorOperands[0],
-          LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
+          /*index=*/vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
       // Create the gather intrinsic.
       return rewriter.create<LLVM::masked_gather>(
           loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
@@ -343,9 +345,9 @@ class VectorScatterOpConversion
     VectorType vType = scatter.getVectorType();
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
-    Value ptrs = getIndexedPtrs(
-        rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
-        ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
+    Value ptrs =
+        getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
+                       adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
 
     // Replace with the scatter intrinsic.
     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index d1280aceeb7b6..a01c4ee4923eb 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3076,6 +3076,40 @@ void InlineAsmOp::getEffects(
   }
 }
 
+//===----------------------------------------------------------------------===//
+// masked_gather (intrinsic)
+//===----------------------------------------------------------------------===//
+
+LogicalResult LLVM::masked_gather::verify() {
+  auto ptrsVectorType = getPtrs().getType();
+  Type expectedPtrsVectorType =
+      LLVM::getVectorType(extractVectorElementType(ptrsVectorType),
+                          LLVM::getVectorNumElements(getRes().getType()));
+  // The pointer vector must have the same element count (including the
+  // scalable flag) as the result vector; only the element type may differ.
+  if (ptrsVectorType != expectedPtrsVectorType)
+    return emitOpError("expected operand #1 type to be ")
+           << expectedPtrsVectorType;
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// masked_scatter (intrinsic)
+//===----------------------------------------------------------------------===//
+
+LogicalResult LLVM::masked_scatter::verify() {
+  auto ptrsVectorType = getPtrs().getType();
+  Type expectedPtrsVectorType =
+      LLVM::getVectorType(extractVectorElementType(ptrsVectorType),
+                          LLVM::getVectorNumElements(getValue().getType()));
+  // The pointer vector must have the same element count (including the
+  // scalable flag) as the value vector; only the element type may differ.
+  if (ptrsVectorType != expectedPtrsVectorType)
+    return emitOpError("expected operand #2 type to be ")
+           << expectedPtrsVectorType;
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // LLVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 5f2d2809a0fe8..c310954b906e4 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2248,6 +2248,19 @@ func.func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3
 
 // -----
 
+func.func @gather_op_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %arg2: vector<[3]xi1>, %arg3: vector<[3]xf32>) -> vector<[3]xf32> {
+  %0 = arith.constant 0: index
+  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<[3]xi32>, vector<[3]xi1>, vector<[3]xf32> into vector<[3]xf32>
+  return %1 : vector<[3]xf32>
+}
+
+// CHECK-LABEL: func @gather_op_scalable
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK: return %[[G]] : vector<[3]xf32>
+
+// -----
+
 func.func @gather_op_global_memory(%arg0: memref<?xf32, 1>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
   %0 = arith.constant 0: index
   %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32, 1>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
@@ -2351,6 +2364,18 @@ func.func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<
 
 // -----
 
+func.func @scatter_op_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %arg2: vector<[3]xi1>, %arg3: vector<[3]xf32>) {
+  %0 = arith.constant 0: index
+  vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<[3]xi32>, vector<[3]xi1>, vector<[3]xf32>
+  return
+}
+
+// CHECK-LABEL: func @scatter_op_scalable
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into !llvm.vec<? x 3 x ptr>
+
+// -----
+
 func.func @scatter_op_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: vector<3xi1>, %arg3: vector<3xindex>) {
   %0 = arith.constant 0: index
   vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref<?xindex>, vector<3xindex>, vector<3xi1>, vector<3xindex>

diff  --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index 40f2260574bf5..9cf922ad490a9 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -261,6 +261,14 @@ llvm.func @masked_gather_intr_wrong_type(%ptrs : vector<7xf32>, %mask : vector<7
 
 // -----
 
+llvm.func @masked_gather_intr_wrong_type_scalable(%ptrs : !llvm.vec<7 x ptr>, %mask : vector<[7]xi1>) -> vector<[7]xf32> {
+  // expected-error @below{{expected operand #1 type to be '!llvm.vec<? x 7 x  ptr>'}}
+  %0 = llvm.intr.masked.gather %ptrs, %mask { alignment = 1: i32} : (!llvm.vec<7 x ptr>, vector<[7]xi1>) -> vector<[7]xf32>
+  llvm.return %0 : vector<[7]xf32>
+}
+
+// -----
+
 llvm.func @masked_scatter_intr_wrong_type(%vec : f32, %ptrs : !llvm.vec<7xptr>, %mask : vector<7xi1>) {
   // expected-error @below{{op operand #0 must be LLVM dialect-compatible vector type, but got 'f32'}}
   llvm.intr.masked.scatter %vec, %ptrs, %mask { alignment = 1: i32} : f32, vector<7xi1> into !llvm.vec<7xptr>
@@ -269,6 +277,14 @@ llvm.func @masked_scatter_intr_wrong_type(%vec : f32, %ptrs : !llvm.vec<7xptr>,
 
 // -----
 
+llvm.func @masked_scatter_intr_wrong_type_scalable(%vec : vector<[7]xf32>, %ptrs : !llvm.vec<7xptr>, %mask : vector<[7]xi1>) {
+  // expected-error @below{{expected operand #2 type to be '!llvm.vec<? x 7 x  ptr>'}}
+  llvm.intr.masked.scatter %vec, %ptrs, %mask { alignment = 1: i32} : vector<[7]xf32>, vector<[7]xi1> into !llvm.vec<7xptr>
+  llvm.return
+}
+
+// -----
+
 llvm.func @stepvector_intr_wrong_type() -> vector<7xf32> {
   // expected-error @below{{op result #0 must be LLVM dialect-compatible vector of signless integer, but got 'vector<7xf32>'}}
   %0 = llvm.intr.experimental.stepvector : vector<7xf32>


        


More information about the Mlir-commits mailing list