[Mlir-commits] [mlir] [mlir][ptr] Add high-level read and write memory operations (PR #161081)

Fabian Mora llvmlistbot at llvm.org
Sun Sep 28 06:46:11 PDT 2025


https://github.com/fabianmcg created https://github.com/llvm/llvm-project/pull/161081

Add `ptr.read` and `ptr.write` operations to the pointer dialect. These operations
provide a high-level interface for reading from and writing to memory with:

- Masked access semantics (conditional loads/stores)
- Contiguity information for optimized lowering
- Support for both vector and tensor types
- Ability to express row-major, column-major, and gather/scatter patterns

It's future work to add lowerings to ptr.load/store, ptr.masked_load/store, or
ptr.gather/scatter depending on mask and contiguity properties.

Example:
```mlir
func.func @read(%ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>, %passthrough: vector<4x4xf32>) {
  // Row-major styled read
  %0 = ptr.read %ptr, %mask, %passthrough contiguity = [1, 4] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
  // Column-major styled read
  %1 = ptr.read %ptr, %mask, %passthrough alignment = 8 contiguity = [4, 1] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
  // Gather styled read
  %2 = ptr.read %ptr, %mask, %passthrough alignment = 8 contiguity = [1, 1] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
  return
}

func.func @write(%value: vector<4x4xf32>, %ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>) {
  // Row-major styled write
  ptr.write %value, %ptr, %mask contiguity = [1, 4] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
  // Column-major styled write
  ptr.write %value, %ptr, %mask alignment = 8 contiguity = [4, 1] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
  // Scatter styled write
  ptr.write %value, %ptr, %mask alignment = 8 contiguity = [1, 1] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
  return
}
```

>From 2b8c89d36e3e4576b6d49778bc2c0df8c2faeb33 Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Sun, 28 Sep 2025 13:41:26 +0000
Subject: [PATCH] add read-write ops

---
 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 248 +++++++++++++++++++++
 mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp     | 105 +++++++++
 mlir/test/Dialect/Ptr/invalid.mlir         |  48 ++++
 mlir/test/Dialect/Ptr/ops.mlir             |  36 +++
 4 files changed, 437 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index e14f64330c294..c3a5415d0cbc8 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -24,6 +24,8 @@ include "mlir/IR/OpAsmInterface.td"
 
 def AlignmentProp : OptionalProp<I64Prop>;
 
+def ContiguityProp : IntArrayProp<I32Prop, "memory access contiguity information">;
+
 //===----------------------------------------------------------------------===//
 // Common types
 //===----------------------------------------------------------------------===//
@@ -45,6 +47,15 @@ def Ptr_IntLikeType :AnyTypeOf<[
   AnySignlessIntegerOrIndex
 ]>;
 
+// A shaped pointer type with value semantics.
+def Ptr_ShapedPtrType : Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>;
+
+// A shaped mask type with value semantics.
+def Ptr_ShapedMaskType : Ptr_ShapedValueType<[I1], [HasRankPred]>;
+
+// A shaped type of any element type with value semantics.
+def Ptr_ShapedAnyType : Ptr_ShapedValueType<[AnyType], [HasRankPred]>;
+
 // A shaped value type of rank 1 of any element type.
 def Ptr_Any1DType :
   Ptr_ShapedValueType<[AnyType], [HasAnyRankOfPred<[1]>]>;
@@ -472,6 +483,127 @@ def Ptr_PtrDiffOp : Pointer_Op<"ptr_diff", [
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ReadOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_ReadOp : Pointer_Op<"read", [
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    TypesMatchWith<"result and mask must be compatible", "result", "mask", [{
+      ::llvm::cast<ShapedType>($_self).clone(
+        IntegerType::get($_self.getContext(), 1))
+    }]>,
+    AllTypesMatch<["result", "passthrough"]>,
+    // Check the shapes are compatible and both use the same shaped container
+    // type.
+    AllShapesMatch<["result", "ptr"]>, AllTypeIDsMatch<["result", "ptr"]>
+  ]> {
+  let summary = "Read operation";
+  let description = [{
+    The `read` operation is a high-level operation that performs a read
+    from multiple memory locations specified by `ptr` based on a mask `mask`.
+    Elements of the `result`, corresponding to masked-off lanes, are taken from
+    the `passthrough` operand.
+
+    The `mask` operand is a shaped type of `i1` elements that must have the same
+    shape as the result type.
+
+    The `contiguity` property is an integer array with the same rank as `ptr`,
+    where each element describes memory access contiguity for the corresponding
+    dimension. The precise semantics of this property are given by:
+    Let `c1, c2, ..., cn` be the elements of the contiguity array, and
+    `s1, s2, ..., sn` be the corresponding elements of the `ptr` shape.
+    The following rules and restrictions apply:
+      1. `ck` must be strictly positive for all k.
+      2. `ck` must divide `sk` for all k.
+      3. Given arbitrary but valid indices `j1, ..., jn`, then the memory ranges
+         given by:
+           - `ptr[j1, ..., jk-1, l * ck : (l + 1) * ck, jk+1, ..., jn]` for
+           `l = 0, 1, ..., sk / ck - 1`
+         are contiguous for all k.
+
+    It is undefined behavior if the pointers in `ptr` do not satisfy the
+    contiguity constraints specified by `contiguity`.
+
+    Depending on the values of `mask` and `contiguity`, the operation can be
+    lowered to either:
+    1. A `ptr.load`, if the mask is all ones, and there's a dimension where all
+       the accesses are contiguous.
+    2. A `ptr.masked_load`, if the mask is not all ones, and there's a dimension
+       where all the accesses are contiguous.
+    3. A `ptr.gather`, if there's no contiguous dimension, regardless of the
+       mask.
+
+    The alignment property describes the alignment (in bytes) of each contiguous
+    memory-block being accessed.
+
+    Examples:
+    ```mlir
+    // Read a vector in row-major order
+    %result = ptr.read %ptr, %mask, %passthrough contiguity = [1, 4] :
+      vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+
+    // Read a vector in column-major order with alignment
+    %result = ptr.read %ptr, %mask, %passthrough alignment = 8
+      contiguity = [4, 1] :
+      vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+
+    // Gather a vector from memory
+    %result = ptr.read %ptr, %mask, %passthrough alignment = 8
+      contiguity = [1, 1] :
+      vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+    ```
+  }];
+  let arguments = (ins Ptr_ShapedPtrType:$ptr,
+                       Ptr_ShapedMaskType:$mask,
+                       Ptr_ShapedAnyType:$passthrough,
+                       AlignmentProp:$alignment,
+                       ContiguityProp:$contiguity);
+  let results = (outs Ptr_ShapedAnyType:$result);
+  let assemblyFormat = [{
+    $ptr `,` $mask `,` $passthrough (`alignment` `=` $alignment^)?
+    `contiguity` `=` $contiguity attr-dict `:` type($ptr) `->` type($result)
+  }];
+  let builders = [
+    OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$passthrough,
+      CArg<"unsigned", "0">:$alignment,
+      CArg<"ArrayRef<int32_t>", "{}">:$contiguity)>
+  ];
+  let hasVerifier = 1;
+  let extraClassDeclaration = [{
+    /// Returns the ptr type of the operation.
+    PtrType getPtrType()  {
+      return cast<PtrType>(getPtr().getType().getElementType());
+    }
+
+    /// Returns the rank of the shaped operands and result.
+    unsigned getRank() { return getType().getRank(); }
+
+    /// Returns the shape of the shaped operands and result.
+    ArrayRef<int64_t> getShape() { return getType().getShape(); }
+
+    /// Returns a pair `(c, s)` where `c` is the contiguity and `s` the size
+    /// of the `i`-th dimension.
+    std::pair<int64_t, int64_t> getContiguityInfo(unsigned i) {
+      assert(i < getRank() && "Invalid dimension");
+      return {getContiguity()[i], getShape()[i]};
+    }
+
+    /// Returns true if the `i`-th dimension is contiguous.
+    bool isContiguous(unsigned i) {
+      auto [contiguity, size] = getContiguityInfo(i);
+      return contiguity == size && size > 1;
+    }
+
+    /// Returns true if the read has gather semantics, i.e. there's no
+    /// dimension where all the accesses are contiguous.
+    bool hasGatherSemantics() {
+      return !llvm::any_of(llvm::seq<unsigned>(0, getRank()),
+        [this](unsigned i) { return isContiguous(i); });
+    }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//
@@ -645,4 +777,120 @@ def Ptr_TypeOffsetOp : Pointer_Op<"type_offset", [Pure]> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// WriteOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_WriteOp : Pointer_Op<"write", [
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    TypesMatchWith<"value and mask must be compatible",
+                                          "value", "mask", [{
+      cast<ShapedType>($_self).clone(IntegerType::get($_self.getContext(), 1))
+    }]>,
+    // Check the shapes are compatible and both use the same shaped container
+    AllShapesMatch<["value", "ptr"]>, AllTypeIDsMatch<["value", "ptr"]>
+  ]> {
+  let summary = "Write operation";
+  let description = [{
+    The `write` operation is a high-level operation that performs a write to
+    multiple memory locations specified by `ptr` based on a mask `mask`.
+    Elements of the `value`, corresponding to masked-off lanes, are not written
+    to memory.
+
+    The `mask` operand is a shaped type of `i1` elements that must have the same
+    shape as the `value` type.
+
+    The `contiguity` property is an integer array with the same rank as `ptr`,
+    where each element describes memory access contiguity for the corresponding
+    dimension. The precise semantics of this property are given by:
+    Let `c1, c2, ..., cn` be the elements of the contiguity array, and
+    `s1, s2, ..., sn` be the corresponding elements of the `ptr` shape.
+    The following rules and restrictions apply:
+      1. `ck` must be strictly positive for all k.
+      2. `ck` must divide `sk` for all k.
+      3. Given arbitrary but valid indices `j1, ..., jn`, then the memory ranges
+         given by:
+           - `ptr[j1, ..., jk-1, l * ck : (l + 1) * ck, jk+1, ..., jn]` for
+           `l = 0, 1, ..., sk / ck - 1`
+         are contiguous for all k.
+
+    It is undefined behavior if the pointers in `ptr` do not satisfy the
+    contiguity constraints specified by `contiguity`.
+
+    Depending on the values of `mask` and `contiguity`, the operation can be
+    lowered to either:
+    1. A `ptr.store`, if the mask is all ones, and there's a dimension where all
+       the accesses are contiguous.
+    2. A `ptr.masked_store`, if the mask is not all ones, and there's a dimension
+       where all the accesses are contiguous.
+    3. A `ptr.scatter`, if there's no contiguous dimension, regardless of the
+       mask.
+
+    The alignment property describes the alignment (in bytes) of each contiguous
+    memory-block being accessed.
+
+    Examples:
+    ```mlir
+    // Write a vector in row-major order
+    ptr.write %value, %ptr, %mask contiguity = [1, 4] :
+      vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+
+    // Write a vector in column-major order with alignment
+    ptr.write %value, %ptr, %mask alignment = 8 contiguity = [4, 1] :
+      vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+
+    // Scatter a vector to memory
+    ptr.write %value, %ptr, %mask alignment = 8 contiguity = [1, 1] :
+      vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+    ```
+  }];
+  let arguments = (ins Ptr_ShapedAnyType:$value,
+                       Ptr_ShapedPtrType:$ptr,
+                       Ptr_ShapedMaskType:$mask,
+                       AlignmentProp:$alignment,
+                       ContiguityProp:$contiguity);
+  let assemblyFormat = [{
+    $value `,` $ptr `,` $mask (`alignment` `=` $alignment^)?
+    `contiguity` `=` $contiguity attr-dict `:` type($value) `,` type($ptr)
+  }];
+  let builders = [
+    OpBuilder<(ins "Value":$value, "Value":$ptr, "Value":$mask,
+      CArg<"unsigned", "0">:$alignment,
+      CArg<"ArrayRef<int32_t>", "{}">:$contiguity)>
+  ];
+  let hasVerifier = 1;
+  let extraClassDeclaration = [{
+    /// Returns the ptr type of the operation.
+    PtrType getPtrType()  {
+      return cast<PtrType>(getPtr().getType().getElementType());
+    }
+
+    /// Returns the rank of the shaped operands.
+    unsigned getRank() { return getPtr().getType().getRank(); }
+
+    /// Returns the shape of the shaped operands.
+    ArrayRef<int64_t> getShape() { return getPtr().getType().getShape(); }
+
+    /// Returns a pair `(c, s)` where `c` is the contiguity and `s` the size
+    /// of the `i`-th dimension.
+    std::pair<int64_t, int64_t> getContiguityInfo(unsigned i) {
+      assert(i < getRank() && "Invalid dimension");
+      return {getContiguity()[i], getShape()[i]};
+    }
+
+    /// Returns true if the `i`-th dimension is contiguous.
+    bool isContiguous(unsigned i) {
+      auto [contiguity, size] = getContiguityInfo(i);
+      return contiguity == size && size > 1;
+    }
+
+    /// Returns true if the write has scatter semantics, i.e. there's no
+    /// dimension where all the accesses are contiguous.
+    bool hasScatterSemantics() {
+      return !llvm::any_of(llvm::seq<unsigned>(0, getRank()),
+        [this](unsigned i) { return isContiguous(i); });
+    }
+  }];
+}
+
 #endif // PTR_OPS
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index 51f25f755a8a6..ecfbd957bbe24 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -57,6 +57,25 @@ verifyAlignment(std::optional<int64_t> alignment,
   return success();
 }
 
+/// Verifies that the contiguity array has the right size, all the elements are
+/// positive and divide the corresponding shape dimension.
+static LogicalResult
+verifyContiguityProp(ArrayRef<int32_t> contiguity, ArrayRef<int64_t> shape,
+                     function_ref<InFlightDiagnostic()> emitError) {
+  if (contiguity.size() != shape.size()) {
+    return emitError() << "expected contiguity array with " << shape.size()
+                       << " elements";
+  }
+  if (!llvm::all_of(llvm::zip(contiguity, shape), [](auto cs) {
+        int32_t c = std::get<0>(cs);
+        return c > 0 && std::get<1>(cs) % c == 0;
+      })) {
+    return emitError()
+           << "expected contiguity values to be positive and divide the shape";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//
@@ -264,6 +283,49 @@ void MaskedStoreOp::build(OpBuilder &builder, OperationState &state,
         alignment ? std::optional<int64_t>(alignment) : std::nullopt);
 }
 
+//===----------------------------------------------------------------------===//
+// ReadOp
+//===----------------------------------------------------------------------===//
+
+void ReadOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable());
+}
+
+LogicalResult ReadOp::verify() {
+  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+
+  // Verify that the pointer type's memory space allows loads.
+  MemorySpaceAttrInterface ms =
+      cast<PtrType>(getPtr().getType().getElementType()).getMemorySpace();
+  DataLayout dataLayout = DataLayout::closest(*this);
+  if (!ms.isValidLoad(getResult().getType(), AtomicOrdering::not_atomic,
+                      getAlignment(), &dataLayout, emitDiag))
+    return failure();
+
+  // Verify the alignment.
+  if (failed(verifyAlignment(getAlignment(), emitDiag)))
+    return failure();
+
+  // Verify the contiguity array.
+  return verifyContiguityProp(getContiguity(), getShape(), emitDiag);
+}
+
+void ReadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
+                   Value mask, Value passthrough, unsigned alignment,
+                   ArrayRef<int32_t> contiguity) {
+  if (!contiguity.empty()) {
+    build(builder, state, ptr, mask, passthrough,
+          alignment ? std::optional<int64_t>(alignment) : std::nullopt,
+          contiguity);
+    return;
+  }
+  build(builder, state, ptr, mask, passthrough,
+        alignment ? std::optional<int64_t>(alignment) : std::nullopt,
+        SmallVector<int32_t>(cast<ShapedType>(ptr.getType()).getRank(), 1));
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//
@@ -470,6 +532,49 @@ llvm::TypeSize TypeOffsetOp::getTypeSize(std::optional<DataLayout> layout) {
   return dl.getTypeSize(getElementType());
 }
 
+//===----------------------------------------------------------------------===//
+// WriteOp
+//===----------------------------------------------------------------------===//
+
+void WriteOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Write::get(), &getPtrMutable());
+}
+
+LogicalResult WriteOp::verify() {
+  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+
+  // Verify that the pointer type's memory space allows stores.
+  MemorySpaceAttrInterface ms =
+      cast<PtrType>(getPtr().getType().getElementType()).getMemorySpace();
+  DataLayout dataLayout = DataLayout::closest(*this);
+  if (!ms.isValidStore(getValue().getType(), AtomicOrdering::not_atomic,
+                       getAlignment(), &dataLayout, emitDiag))
+    return failure();
+
+  // Verify the alignment.
+  if (failed(verifyAlignment(getAlignment(), emitDiag)))
+    return failure();
+
+  // Verify the contiguity array.
+  return verifyContiguityProp(getContiguity(), getShape(), emitDiag);
+}
+
+void WriteOp::build(OpBuilder &builder, OperationState &state, Value value,
+                    Value ptr, Value mask, unsigned alignment,
+                    ArrayRef<int32_t> contiguity) {
+  if (!contiguity.empty()) {
+    build(builder, state, value, ptr, mask,
+          alignment ? std::optional<int64_t>(alignment) : std::nullopt,
+          contiguity);
+    return;
+  }
+  build(builder, state, value, ptr, mask,
+        alignment ? std::optional<int64_t>(alignment) : std::nullopt,
+        SmallVector<int32_t>(cast<ShapedType>(ptr.getType()).getRank(), 1));
+}
+
 //===----------------------------------------------------------------------===//
 // Pointer API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir
index 83e1c880650c5..54332a5632808 100644
--- a/mlir/test/Dialect/Ptr/invalid.mlir
+++ b/mlir/test/Dialect/Ptr/invalid.mlir
@@ -78,3 +78,51 @@ func.func @ptr_diff_mismatch(%lhs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %rhs:
   %res = ptr.ptr_diff %lhs, %rhs : tensor<8x!ptr.ptr<#ptr.generic_space>> -> vector<8xi64>
   return %res : vector<8xi64>
 }
+
+// -----
+
+func.func @read_contiguity_does_not_divide(%ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>, %passthrough: vector<4x4xf32>) -> vector<4x4xf32> {
+  // expected-error @+1 {{expected contiguity values to be positive and divide the shape}}
+  %0 = ptr.read %ptr, %mask, %passthrough contiguity = [1, 3] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// -----
+
+func.func @read_contiguity_is_not_positive(%ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>, %passthrough: vector<4x4xf32>) -> vector<4x4xf32> {
+  // expected-error @+1 {{expected contiguity values to be positive and divide the shape}}
+  %0 = ptr.read %ptr, %mask, %passthrough contiguity = [1, -1] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// -----
+
+func.func @read_invalid_contiguity_size(%ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>, %passthrough: vector<4x4xf32>) -> vector<4x4xf32> {
+  // expected-error @+1 {{expected contiguity array with 2 elements}}
+  %0 = ptr.read %ptr, %mask, %passthrough contiguity = [1] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// -----
+
+func.func @write_contiguity_does_not_divide(%value: vector<4x4xf32>, %ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>) {
+  // expected-error @+1 {{expected contiguity values to be positive and divide the shape}}
+  ptr.write %value, %ptr, %mask contiguity = [1, 7] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+  return
+}
+
+// -----
+
+func.func @write_contiguity_is_not_positive(%value: vector<4x4xf32>, %ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>) {
+  // expected-error @+1 {{expected contiguity values to be positive and divide the shape}}
+  ptr.write %value, %ptr, %mask contiguity = [0, 4] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+  return
+}
+
+// -----
+
+func.func @write_invalid_contiguity_size(%value: vector<4x4xf32>, %ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>) {
+  // expected-error @+1 {{expected contiguity array with 2 elements}}
+  ptr.write %value, %ptr, %mask contiguity = [1] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+  return
+}
diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir
index 0a906ad559e21..d0c0390d6932e 100644
--- a/mlir/test/Dialect/Ptr/ops.mlir
+++ b/mlir/test/Dialect/Ptr/ops.mlir
@@ -239,3 +239,39 @@ func.func @ptr_diff_tensor_2d_ops(%ptrs1: tensor<4x8x!ptr.ptr<#ptr.generic_space
   %diff = ptr.ptr_diff %ptrs1, %ptrs2 : tensor<4x8x!ptr.ptr<#ptr.generic_space>> -> tensor<4x8xi64>
   return %diff : tensor<4x8xi64>
 }
+
+/// Check read op assembly.
+func.func @read_ops(%ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>, %passthrough: vector<4x4xf32>) {
+  // Row-major styled read
+  %0 = ptr.read %ptr, %mask, %passthrough contiguity = [1, 4] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+  // Column-major styled read
+  %1 = ptr.read %ptr, %mask, %passthrough alignment = 8 contiguity = [4, 1] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+  // Gather styled read
+  %2 = ptr.read %ptr, %mask, %passthrough alignment = 8 contiguity = [1, 1] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+  return
+}
+
+/// Check read op assembly with tensors
+func.func @read_ops_tensor(%ptr: tensor<8x!ptr.ptr<#ptr.generic_space>>, %mask: tensor<8xi1>, %passthrough: tensor<8xf32>) {
+  %0 = ptr.read %ptr, %mask, %passthrough contiguity = [1] : tensor<8x!ptr.ptr<#ptr.generic_space>> -> tensor<8xf32>
+  %1 = ptr.read %ptr, %mask, %passthrough alignment = 4 contiguity = [8] : tensor<8x!ptr.ptr<#ptr.generic_space>> -> tensor<8xf32>
+  return
+}
+
+/// Check write op assembly.
+func.func @write_ops(%value: vector<4x4xf32>, %ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>) {
+  // Row-major styled write
+  ptr.write %value, %ptr, %mask contiguity = [1, 4] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+  // Column-major styled write
+  ptr.write %value, %ptr, %mask alignment = 8 contiguity = [4, 1] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+  // Scatter styled write
+  ptr.write %value, %ptr, %mask alignment = 8 contiguity = [1, 1] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+  return
+}
+
+/// Check write op assembly with tensors
+func.func @write_ops_tensor(%value: tensor<8xf32>, %ptr: tensor<8x!ptr.ptr<#ptr.generic_space>>, %mask: tensor<8xi1>) {
+  ptr.write %value, %ptr, %mask contiguity = [1] : tensor<8xf32>, tensor<8x!ptr.ptr<#ptr.generic_space>>
+  ptr.write %value, %ptr, %mask alignment = 4 contiguity = [8] : tensor<8xf32>, tensor<8x!ptr.ptr<#ptr.generic_space>>
+  return
+}



More information about the Mlir-commits mailing list