[llvm-branch-commits] [mlir] [mlir][ptr] Extend `ptr_add` operation to support shaped operands (PR #156374)

Fabian Mora via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Sep 3 07:45:30 PDT 2025


https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/156374

>From af522ed2b48cee2fe81901f2396025d58341997b Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Mon, 1 Sep 2025 21:05:55 +0000
Subject: [PATCH] extend ptr_add op

---
 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h     |   1 +
 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td    | 103 +++++++++++-------
 mlir/lib/Dialect/Ptr/IR/CMakeLists.txt        |   1 +
 mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp        |  40 +++++++
 .../Conversion/PtrToLLVM/ptr-to-llvm.mlir     |  12 +-
 mlir/test/Dialect/Ptr/invalid.mlir            |  16 +++
 mlir/test/Dialect/Ptr/ops.mlir                |  65 +++++++++++
 mlir/test/Target/LLVMIR/ptr.mlir              |  30 +++++
 8 files changed, 225 insertions(+), 43 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
index 8686cc7d316d4..eaf1e6243a74d 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Ptr/IR/PtrDialect.h"
 #include "mlir/Dialect/Ptr/IR/PtrTypes.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 5939c3646884c..3ac12978b947c 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -13,6 +13,7 @@ include "mlir/Dialect/Ptr/IR/PtrDialect.td"
 include "mlir/Dialect/Ptr/IR/PtrAttrDefs.td"
 include "mlir/Dialect/Ptr/IR/PtrEnums.td"
 include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
 include "mlir/IR/OpAsmInterface.td"
@@ -34,8 +35,15 @@ class Ptr_ShapedValueType<list<Type> allowedTypes, list<Pred> preds = []> :
     /*descr=*/[{A shaped type with value semantics and rank.}],
     /*cppType=*/"::mlir::ShapedType">;
 
-// A shaped pointer type with value semantics and rank.
-class Ptr_ShapedPtrType : Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>;
+// Ptr-like: a `!ptr.ptr` or a ranked, value-semantics shaped type of them.
+def Ptr_PtrLikeType :
+  AnyTypeOf<[Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>, Ptr_PtrType]>;
+
+// Int-like: signless integer/index, or a ranked, value-semantics shape of it.
+def Ptr_IntLikeType : AnyTypeOf<[
+  Ptr_ShapedValueType<[AnySignlessIntegerOrIndex], [HasRankPred]>,
+  AnySignlessIntegerOrIndex
+]>;
 
 // A shaped value type of rank 1 of any element type.
 def Ptr_Any1DType :
@@ -167,41 +175,6 @@ def Ptr_GetMetadataOp : Pointer_Op<"get_metadata", [
   }];
 }
 
-//===----------------------------------------------------------------------===//
-// PtrAddOp
-//===----------------------------------------------------------------------===//
-
-def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
-    Pure, AllTypesMatch<["base", "result"]>, ViewLikeOpInterface
-  ]> {
-  let summary = "Pointer add operation";
-  let description = [{
-    The `ptr_add` operation adds an integer offset to a pointer to produce a new
-    pointer. The input and output pointer types are always the same.
-
-    Example:
-
-    ```mlir
-    %x_off  = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32
-    %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32
-    ```
-  }];
-
-  let arguments = (ins
-    Ptr_PtrType:$base,
-    AnySignlessIntegerOrIndex:$offset,
-    DefaultValuedProp<EnumProp<Ptr_PtrAddFlags>, "PtrAddFlags::none">:$flags);
-  let results = (outs Ptr_PtrType:$result);
-  let assemblyFormat = [{
-    ($flags^)? $base `,` $offset attr-dict `:` type($base) `,` type($offset)
-  }];
-  let hasFolder = 1;
-  let extraClassDeclaration = [{
-    /// `ViewLikeOp::getViewSource` method. 
-    Value getViewSource() { return getBase(); }
-  }];
-}
-
 //===----------------------------------------------------------------------===//
 // LoadOp
 //===----------------------------------------------------------------------===//
@@ -361,6 +334,62 @@ def Ptr_MaskedStoreOp : Pointer_Op<"masked_store", [
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// PtrAddOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
+    Pure, ViewLikeOpInterface,
+    DeclareOpInterfaceMethods<InferTypeOpInterface>
+  ]> {
+  let summary = "Pointer add operation";
+  let description = [{
+    The `ptr_add` operation adds an int-like offset to one or more pointers to
+    produce one or more new pointers; the result type is inferred.
+
+    Both scalar and shaped operands with value semantics are supported:
+    - scalar base, scalar offset: produces a single pointer of the base type.
+    - shaped base, scalar offset: adds the same offset to every base pointer.
+    - scalar base, shaped offset: adds each offset to the single base pointer;
+      the result has the offset's shape with the base's pointer element type.
+    - shaped base, shaped offset: element-wise addition; both operands must be
+      the same kind of shaped type (e.g. both vectors) with identical shapes.
+
+    Example:
+
+    ```mlir
+    // Scalar base and offset
+    %x_off  = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32
+    %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32
+
+    // Shaped base with scalar offset
+    %ptrs_off = ptr.ptr_add %ptrs, %off : vector<4x!ptr.ptr<#ptr.generic_space>>, i32
+
+    // Scalar base with shaped offset
+    %x_offs = ptr.ptr_add %x, %offs : !ptr.ptr<#ptr.generic_space>, vector<4xi32>
+
+    // Both base and offset are shaped
+    %ptrs_offs = ptr.ptr_add %ptrs, %offs : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xi32>
+    ```
+  }];
+  let arguments = (ins
+    Ptr_PtrLikeType:$base,
+    Ptr_IntLikeType:$offset,
+    DefaultValuedProp<EnumProp<Ptr_PtrAddFlags>, "PtrAddFlags::none">:$flags);
+  let results = (outs Ptr_PtrLikeType:$result);
+  let assemblyFormat = [{
+    ($flags^)? $base `,` $offset attr-dict `:` type($base) `,` type($offset)
+  }];
+  let hasFolder = 1;
+  let extraClassDeclaration = [{
+    /// `ViewLikeOp::getViewSource` method.
+    Value getViewSource() { return getBase(); }
+
+    /// Returns the ptr type of the operation.
+    ptr::PtrType getPtrType();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
index bd1e655fc6b5e..a6b0d416a4165 100644
--- a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
@@ -33,6 +33,7 @@ add_mlir_dialect_library(
   MLIRIR
   MLIRDataLayoutInterfaces
   MLIRMemorySlotInterfaces
+  MLIRInferTypeOpInterface
   MLIRViewLikeInterface
   MLIRPtrMemorySpaceInterfaces
 )
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index 92ce9be97dd2c..6697f5382db6b 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -346,6 +346,62 @@ OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }
 
+LogicalResult PtrAddOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  // `operands` holds (base, offset), in declaration order.
+  Type baseType = operands[0].getType();
+  Type offsetType = operands[1].getType();
+  auto baseShapedType = dyn_cast<ShapedType>(baseType);
+  auto offsetShapedType = dyn_cast<ShapedType>(offsetType);
+
+  // Scalar base and scalar offset: the result is the base pointer type.
+  if (!baseShapedType && !offsetShapedType) {
+    inferredReturnTypes.push_back(baseType);
+    return success();
+  }
+
+  // Scalar base and shaped offset: the result has the offset's shape and
+  // container kind, with the base pointer type as element type.
+  if (!baseShapedType) {
+    inferredReturnTypes.push_back(offsetShapedType.clone(baseType));
+    return success();
+  }
+
+  // Shaped base and shaped offset: the two containers must agree exactly.
+  if (offsetShapedType) {
+    if (offsetShapedType.getShape() != baseShapedType.getShape())
+      return emitOptionalError(location,
+                               "shapes of base and offset must match");
+    // Make sure they are the same kind of shaped type.
+    if (baseType.getTypeID() != offsetType.getTypeID())
+      return emitOptionalError(location,
+                               "the shaped containers type must match");
+    // `getShape()` ignores scalability, so e.g. `vector<[4]x...>` and
+    // `vector<4x...>` must be told apart explicitly.
+    if (auto baseVectorType = dyn_cast<VectorType>(baseType)) {
+      if (baseVectorType.getScalableDims() !=
+          cast<VectorType>(offsetType).getScalableDims())
+        return emitOptionalError(
+            location, "scalable dimensions of base and offset must match");
+    }
+  }
+
+  // Shaped base with a scalar or matching shaped offset: result is the base.
+  inferredReturnTypes.push_back(baseType);
+  return success();
+}
+
+/// Returns the `!ptr.ptr` type of the result; for a shaped result this is
+/// its element type.
+ptr::PtrType PtrAddOp::getPtrType() {
+  Type resultType = getResult().getType();
+  if (auto shapedType = dyn_cast<ShapedType>(resultType))
+    return cast<ptr::PtrType>(shapedType.getElementType());
+  return cast<ptr::PtrType>(resultType);
+}
+
 //===----------------------------------------------------------------------===//
 // ToPtrOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir b/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir
index dc645fe0480fa..5128fd8ccb265 100644
--- a/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir
+++ b/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir
@@ -16,10 +16,10 @@
 // CHECK:           llvm.return %[[VAL_8]] : !llvm.struct<(ptr, ptr, ptr, ptr)>
 // CHECK:         }
 func.func @test_ptr_add(%arg0: !ptr.ptr<#ptr.generic_space>, %arg1: index) -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>) {
-  %0 = ptr.ptr_add %arg0, %arg1 : <#ptr.generic_space>, index
-  %1 = ptr.ptr_add nusw %arg0, %arg1 : <#ptr.generic_space>, index
-  %2 = ptr.ptr_add nuw %arg0, %arg1 : <#ptr.generic_space>, index
-  %3 = ptr.ptr_add inbounds %arg0, %arg1 : <#ptr.generic_space>, index
+  %0 = ptr.ptr_add %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
+  %1 = ptr.ptr_add nusw %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
+  %2 = ptr.ptr_add nuw %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
+  %3 = ptr.ptr_add inbounds %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
   return %0, %1, %2, %3 : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>
 }
 
@@ -263,7 +263,7 @@ func.func @test_comprehensive_dynamic(%arg0: memref<?x?xf32, strided<[?, ?], off
   %0 = ptr.to_ptr %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space> -> <#ptr.generic_space>
   %1 = ptr.get_metadata %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
   %2 = ptr.type_offset f32 : index
-  %3 = ptr.ptr_add inbounds %0, %2 : <#ptr.generic_space>, index
+  %3 = ptr.ptr_add inbounds %0, %2 : !ptr.ptr<#ptr.generic_space>, index
   %4 = ptr.from_ptr %3 metadata %1 : <#ptr.generic_space> -> memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
   return %4 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
 }
@@ -313,6 +313,6 @@ func.func @test_memref_ptradd_indexing(%arg0: memref<10x?x30xf32, #ptr.generic_s
   %0 = ptr.to_ptr %arg0 : memref<10x?x30xf32, #ptr.generic_space> -> <#ptr.generic_space>
   %1 = ptr.type_offset f32 : index
   %2 = arith.muli %1, %arg1 : index
-  %3 = ptr.ptr_add %0, %2 : <#ptr.generic_space>, index
+  %3 = ptr.ptr_add %0, %2 : !ptr.ptr<#ptr.generic_space>, index
   return %3 : !ptr.ptr<#ptr.generic_space>
 }
diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir
index 0c34ae43bf6be..cc1eeb3cb5744 100644
--- a/mlir/test/Dialect/Ptr/invalid.mlir
+++ b/mlir/test/Dialect/Ptr/invalid.mlir
@@ -54,3 +54,21 @@ func.func @llvm_store(%arg0: !ptr.ptr<#llvm.address_space<1>>, %arg1: memref<f32
   ptr.store %arg1, %arg0 : memref<f32>, !ptr.ptr<#llvm.address_space<1>>
   return
 }
+
+// -----
+
+// A tensor base cannot be combined with a vector offset.
+func.func @ptr_add_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: vector<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+  // expected-error@+1 {{the shaped containers type must match}}
+  %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, vector<8xi64>
+  return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
+
+// -----
+
+// Shaped base and offset must have identical shapes.
+func.func @ptr_add_shape_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<4xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+  // expected-error@+1 {{shapes of base and offset must match}}
+  %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<4xi64>
+  return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir
index bde2fb22b6aa0..c008b858af0d7 100644
--- a/mlir/test/Dialect/Ptr/ops.mlir
+++ b/mlir/test/Dialect/Ptr/ops.mlir
@@ -11,6 +11,8 @@ func.func @ptr_add_type_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#
   return %res : !ptr.ptr<#ptr.generic_space>
 }
 
+
+
 /// Check cast ops assembly.
 func.func @cast_ops(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
   %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
@@ -126,3 +128,66 @@ func.func @llvm_masked_ops(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector
   ptr.masked_store %value, %ptr, %mask alignment = 4 : vector<4xf32>, !ptr.ptr<#llvm.address_space<3>>
   return %0 : vector<4xf32>
 }
+
+/// Check ptr_add assembly: vector base, vector offsets, with and without flags.
+func.func @ptr_add_vector(%ptrs: vector<4x!ptr.ptr<#ptr.generic_space>>, %offsets: vector<4xindex>) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+  %res0 = ptr.ptr_add none %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+  %res1 = ptr.ptr_add nusw %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+  %res2 = ptr.ptr_add nuw %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+  %res3 = ptr.ptr_add inbounds %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+  return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Check ptr_add assembly: 1-D tensor base with 1-D i64 tensor offsets.
+func.func @ptr_add_tensor(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<8xi64>
+  return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Check ptr_add assembly: 2-D tensor base and offsets, with and without `nuw`.
+func.func @ptr_add_tensor_2d(%ptrs: tensor<4x8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<4x8xindex>) -> tensor<4x8x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
+  %res1 = ptr.ptr_add nuw %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
+  return %res : tensor<4x8x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Check ptr_add assembly: scalar base with vector offsets (vector result).
+func.func @ptr_add_scalar_base_vector_offsets(%ptr: !ptr.ptr<#ptr.generic_space>, %offsets: vector<4xindex>) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+  %res0 = ptr.ptr_add none %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+  %res1 = ptr.ptr_add nusw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+  %res2 = ptr.ptr_add nuw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+  %res3 = ptr.ptr_add inbounds %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+  return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Check ptr_add assembly: scalar base with tensor offsets (tensor result).
+func.func @ptr_add_scalar_base_tensor_offsets(%ptr: !ptr.ptr<#ptr.generic_space>, %offsets: tensor<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+  %res0 = ptr.ptr_add none %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+  %res1 = ptr.ptr_add nusw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+  %res2 = ptr.ptr_add nuw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+  %res3 = ptr.ptr_add inbounds %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+  return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Check ptr_add assembly: vector base with a single scalar `index` offset.
+func.func @ptr_add_vector_base_scalar_offset(%ptrs: vector<4x!ptr.ptr<#ptr.generic_space>>, %offset: index) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+  %res0 = ptr.ptr_add none %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+  %res1 = ptr.ptr_add nusw %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+  %res2 = ptr.ptr_add nuw %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+  %res3 = ptr.ptr_add inbounds %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+  return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Check ptr_add assembly: tensor base with a single scalar `i64` offset.
+func.func @ptr_add_tensor_base_scalar_offset(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offset: i64) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+  %res0 = ptr.ptr_add none %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+  %res1 = ptr.ptr_add nusw %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+  %res2 = ptr.ptr_add nuw %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+  %res3 = ptr.ptr_add inbounds %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+  return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
diff --git a/mlir/test/Target/LLVMIR/ptr.mlir b/mlir/test/Target/LLVMIR/ptr.mlir
index 545bec5979b2d..4b29be813fa81 100644
--- a/mlir/test/Target/LLVMIR/ptr.mlir
+++ b/mlir/test/Target/LLVMIR/ptr.mlir
@@ -203,3 +203,33 @@ llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>
   ptr.masked_store %value, %ptr, %mask alignment = 8 : vector<4xf64>, !ptr.ptr<#llvm.address_space<3>>
   llvm.return
 }
+
+// CHECK-LABEL: define <4 x ptr> @ptr_add_vector
+// CHECK-SAME: (<4 x ptr> %[[BASES:.*]], <4 x i32> %[[IDXS:.*]]) {
+// CHECK-NEXT:   %[[GEP:.*]] = getelementptr i8, <4 x ptr> %[[BASES]], <4 x i32> %[[IDXS]]
+// CHECK-NEXT:   ret <4 x ptr> %[[GEP]]
+// CHECK-NEXT: }
+llvm.func @ptr_add_vector(%bases: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %idxs: vector<4xi32>) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
+  %gep = ptr.ptr_add %bases, %idxs : vector<4x!ptr.ptr<#llvm.address_space<0>>>, vector<4xi32>
+  llvm.return %gep : vector<4x!ptr.ptr<#llvm.address_space<0>>>
+}
+
+// CHECK-LABEL: define <4 x ptr> @ptr_add_scalar_base_vector_offsets
+// CHECK-SAME: (ptr %[[BASE:.*]], <4 x i32> %[[IDXS:.*]]) {
+// CHECK-NEXT:   %[[GEP:.*]] = getelementptr i8, ptr %[[BASE]], <4 x i32> %[[IDXS]]
+// CHECK-NEXT:   ret <4 x ptr> %[[GEP]]
+// CHECK-NEXT: }
+llvm.func @ptr_add_scalar_base_vector_offsets(%base: !ptr.ptr<#llvm.address_space<0>>, %idxs: vector<4xi32>) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
+  %gep = ptr.ptr_add %base, %idxs : !ptr.ptr<#llvm.address_space<0>>, vector<4xi32>
+  llvm.return %gep : vector<4x!ptr.ptr<#llvm.address_space<0>>>
+}
+
+// CHECK-LABEL: define <4 x ptr> @ptr_add_vector_base_scalar_offset
+// CHECK-SAME: (<4 x ptr> %[[BASES:.*]], i32 %[[IDX:.*]]) {
+// CHECK-NEXT:   %[[GEP:.*]] = getelementptr i8, <4 x ptr> %[[BASES]], i32 %[[IDX]]
+// CHECK-NEXT:   ret <4 x ptr> %[[GEP]]
+// CHECK-NEXT: }
+llvm.func @ptr_add_vector_base_scalar_offset(%bases: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %idx: i32) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
+  %gep = ptr.ptr_add %bases, %idx : vector<4x!ptr.ptr<#llvm.address_space<0>>>, i32
+  llvm.return %gep : vector<4x!ptr.ptr<#llvm.address_space<0>>>
+}



More information about the llvm-branch-commits mailing list