[Mlir-commits] [mlir] [mlir][ptr] Add high-level read and write memory operations (PR #161081)
Fabian Mora
llvmlistbot at llvm.org
Sun Sep 28 06:46:11 PDT 2025
https://github.com/fabianmcg created https://github.com/llvm/llvm-project/pull/161081
Add `ptr.read` and `ptr.write` operations to the pointer dialect. These operations
provide a high-level interface for reading from and writing to memory with:
- Masked access semantics (conditional loads/stores)
- Contiguity information for optimized lowering
- Support for both vector and tensor types
- Ability to express row-major, column-major, and gather/scatter patterns
Adding lowerings to `ptr.load`/`ptr.store`, `ptr.masked_load`/`ptr.masked_store`, or
`ptr.gather`/`ptr.scatter`, chosen according to the mask and contiguity properties, is left as future work.
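The lowering choice described as future work can be sketched in plain C++. This is a minimal sketch, not part of the patch: `ReadLowering`, `isContiguousDim`, `chooseReadLowering` and the `maskIsAllOnes` flag are hypothetical names. The contiguity test mirrors the `isContiguous` and `hasGatherSemantics` helpers that the patch adds to `ReadOp`. Falling back to a gather when the mask is all ones but no dimension is contiguous is an assumption, because the op description only covers the other three cases.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical lowering targets for a `ptr.read`; the write side would map to
// store / masked_store / scatter in the same way.
enum class ReadLowering { Load, MaskedLoad, Gather };

// Mirrors `ReadOp::isContiguous`: dimension `i` is fully contiguous when its
// contiguity equals its size and that size is greater than 1.
bool isContiguousDim(const std::vector<int32_t> &contiguity,
                     const std::vector<int64_t> &shape, std::size_t i) {
  return contiguity[i] == shape[i] && shape[i] > 1;
}

// Mirrors `ReadOp::hasGatherSemantics`: no dimension is fully contiguous.
bool hasGatherSemantics(const std::vector<int32_t> &contiguity,
                        const std::vector<int64_t> &shape) {
  assert(contiguity.size() == shape.size() && "rank mismatch");
  for (std::size_t i = 0; i < shape.size(); ++i)
    if (isContiguousDim(contiguity, shape, i))
      return false;
  return true;
}

// `maskIsAllOnes` stands in for an analysis of the mask operand (for example,
// recognizing a constant all-true splat); it is an assumption of this sketch.
ReadLowering chooseReadLowering(bool maskIsAllOnes,
                                const std::vector<int32_t> &contiguity,
                                const std::vector<int64_t> &shape) {
  // Assumption: with no contiguous dimension, fall back to a gather even when
  // the mask is all ones (the op description leaves this case open).
  if (hasGatherSemantics(contiguity, shape))
    return ReadLowering::Gather;
  return maskIsAllOnes ? ReadLowering::Load : ReadLowering::MaskedLoad;
}
```

For the 4x4 examples, `[1, 4]` with an all-true mask would pick a plain load, the same array with a partial mask would pick a masked load, and `[1, 1]` would always pick a gather.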
Example:
```mlir
func.func @read(%ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>, %passthrough: vector<4x4xf32>) {
  // Row-major styled read
  %0 = ptr.read %ptr, %mask, %passthrough contiguity = [1, 4] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
  // Column-major styled read
  %1 = ptr.read %ptr, %mask, %passthrough alignment = 8 contiguity = [4, 1] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
  // Gather styled read
  %2 = ptr.read %ptr, %mask, %passthrough alignment = 8 contiguity = [1, 1] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
  return
}

func.func @write(%value: vector<4x4xf32>, %ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>) {
  // Row-major styled write
  ptr.write %value, %ptr, %mask contiguity = [1, 4] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
  // Column-major styled write
  ptr.write %value, %ptr, %mask alignment = 8 contiguity = [4, 1] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
  // Scatter styled write
  ptr.write %value, %ptr, %mask alignment = 8 contiguity = [1, 1] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
  return
}
```
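To make the `contiguity` arrays in these examples concrete, the following sketch checks whether a given address layout satisfies a contiguity claim (rule 3 of the op description), restricted to rank 2 for brevity. It assumes that "contiguous" means each lane in a block of `ck` lanes along dimension `k` points exactly one element past the previous lane of that block. It also assumes that the array has already passed the verifier, so every entry is positive. `AddrGrid`, `rowMajor`, `colMajor` and `satisfiesContiguity` are hypothetical helpers, not part of the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of a rank-2 pointer operand: `addr[i][j]` is the byte
// address that lane (i, j) points to.
using AddrGrid = std::vector<std::vector<uint64_t>>;

// Addresses of a `rows x cols` matrix stored row-major starting at `base`.
AddrGrid rowMajor(std::size_t rows, std::size_t cols, uint64_t base,
                  uint64_t elemBytes) {
  AddrGrid g(rows, std::vector<uint64_t>(cols));
  for (std::size_t i = 0; i < rows; ++i)
    for (std::size_t j = 0; j < cols; ++j)
      g[i][j] = base + (i * cols + j) * elemBytes;
  return g;
}

// Addresses of a `rows x cols` matrix stored column-major starting at `base`.
AddrGrid colMajor(std::size_t rows, std::size_t cols, uint64_t base,
                  uint64_t elemBytes) {
  AddrGrid g(rows, std::vector<uint64_t>(cols));
  for (std::size_t i = 0; i < rows; ++i)
    for (std::size_t j = 0; j < cols; ++j)
      g[i][j] = base + (j * rows + i) * elemBytes;
  return g;
}

// Checks rule 3 for contiguity = {c0, c1}. Assumes both entries are positive,
// as the verifier guarantees.
bool satisfiesContiguity(const AddrGrid &addr,
                         const std::vector<int32_t> &contiguity,
                         uint64_t elemBytes) {
  assert(contiguity.size() == 2 && "this sketch only handles rank 2");
  const std::size_t c0 = static_cast<std::size_t>(contiguity[0]);
  const std::size_t c1 = static_cast<std::size_t>(contiguity[1]);
  for (std::size_t i = 0; i < addr.size(); ++i) {
    for (std::size_t j = 0; j < addr[i].size(); ++j) {
      // A lane that is not the first of its block along dimension 0 must
      // point one element past the lane directly above it.
      if (i % c0 != 0 && addr[i][j] != addr[i - 1][j] + elemBytes)
        return false;
      // Likewise along dimension 1, relative to the lane to its left.
      if (j % c1 != 0 && addr[i][j] != addr[i][j - 1] + elemBytes)
        return false;
    }
  }
  return true;
}
```

Under this reading, a row-major 4x4 `f32` layout satisfies `[1, 4]` but not `[4, 1]`, and a column-major layout satisfies `[4, 1]` but not `[1, 4]`. `[1, 1]` holds for any layout, which is why it corresponds to the gather/scatter style.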
From 2b8c89d36e3e4576b6d49778bc2c0df8c2faeb33 Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Sun, 28 Sep 2025 13:41:26 +0000
Subject: [PATCH] add read-write ops
---
mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 248 +++++++++++++++++++++
mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 105 +++++++++
mlir/test/Dialect/Ptr/invalid.mlir | 48 ++++
mlir/test/Dialect/Ptr/ops.mlir | 36 +++
4 files changed, 437 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index e14f64330c294..c3a5415d0cbc8 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -24,6 +24,8 @@ include "mlir/IR/OpAsmInterface.td"
def AlignmentProp : OptionalProp<I64Prop>;
+def ContiguityProp : IntArrayProp<I32Prop, "memory access contiguity information">;
+
//===----------------------------------------------------------------------===//
// Common types
//===----------------------------------------------------------------------===//
@@ -45,6 +47,15 @@ def Ptr_IntLikeType :AnyTypeOf<[
AnySignlessIntegerOrIndex
]>;
+// A shaped pointer type with value semantics.
+def Ptr_ShapedPtrType : Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>;
+
+// A shaped mask type with value semantics.
+def Ptr_ShapedMaskType : Ptr_ShapedValueType<[I1], [HasRankPred]>;
+
+// A shaped type of any element type with value semantics.
+def Ptr_ShapedAnyType : Ptr_ShapedValueType<[AnyType], [HasRankPred]>;
+
// A shaped value type of rank 1 of any element type.
def Ptr_Any1DType :
Ptr_ShapedValueType<[AnyType], [HasAnyRankOfPred<[1]>]>;
@@ -472,6 +483,127 @@ def Ptr_PtrDiffOp : Pointer_Op<"ptr_diff", [
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// ReadOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_ReadOp : Pointer_Op<"read", [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TypesMatchWith<"result and mask must be compatible", "result", "mask", [{
+ ::llvm::cast<ShapedType>($_self).clone(
+ IntegerType::get($_self.getContext(), 1))
+ }]>,
+ AllTypesMatch<["result", "passthrough"]>,
+ // Check the shapes are compatible and both use the same shaped container
+ // type.
+ AllShapesMatch<["result", "ptr"]>, AllTypeIDsMatch<["result", "ptr"]>
+ ]> {
+ let summary = "Read operation";
+ let description = [{
+ The `read` operation is a high-level operation that reads from multiple
+ memory locations specified by `ptr`, predicated on `mask`. Elements of
+ `result` that correspond to masked-off lanes are taken from the
+ `passthrough` operand.
+
+ The `mask` operand is a shaped type of `i1` elements that must have the same
+ shape as the result type.
+
+ The `contiguity` property is an integer array with one element per
+ dimension of `ptr`, where each element describes the memory access
+ contiguity along that dimension. The precise semantics are as follows:
+ Let `c1, c2, ..., cn` be the elements of the contiguity array, and
+ `s1, s2, ..., sn` be the corresponding elements of the `ptr` shape.
+ The following rules and restrictions apply:
+ 1. `ck` must be strictly positive for all k.
+ 2. `ck` must divide `sk` for all k.
+ 3. For arbitrary but valid indices `j1, ..., jn`, the memory ranges
+ given by:
+ - `ptr[j1, ..., jk-1, l * ck : (l + 1) * ck, jk+1, ..., jn]` for
+ `l = 0, 1, ..., sk / ck - 1`
+ are contiguous for all k.
+
+ It is undefined behavior if the pointers in `ptr` do not satisfy the
+ contiguity constraints specified by `contiguity`.
+
+ Depending on the values of `mask` and `contiguity`, the operation can be
+ lowered to one of:
+ 1. A `ptr.load`, if the mask is all ones and there is a dimension where
+ all the accesses are contiguous.
+ 2. A `ptr.masked_load`, if the mask is not all ones and there is a
+ dimension where all the accesses are contiguous.
+ 3. A `ptr.gather`, if the mask is not all ones and there is no contiguous
+ dimension.
+
+ The `alignment` property describes the alignment (in bytes) of each
+ contiguous memory block being accessed.
+
+ Examples:
+ ```mlir
+ // Read a vector in row-major order
+ %result = ptr.read %ptr, %mask, %passthrough contiguity = [1, 4] :
+ vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+
+ // Read a vector in column-major order with alignment
+ %result = ptr.read %ptr, %mask, %passthrough alignment = 8
+ contiguity = [4, 1] :
+ vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+
+ // Gather a vector from memory
+ %result = ptr.read %ptr, %mask, %passthrough alignment = 8
+ contiguity = [1, 1] :
+ vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+ ```
+ }];
+ let arguments = (ins Ptr_ShapedPtrType:$ptr,
+ Ptr_ShapedMaskType:$mask,
+ Ptr_ShapedAnyType:$passthrough,
+ AlignmentProp:$alignment,
+ ContiguityProp:$contiguity);
+ let results = (outs Ptr_ShapedAnyType:$result);
+ let assemblyFormat = [{
+ $ptr `,` $mask `,` $passthrough (`alignment` `=` $alignment^)?
+ `contiguity` `=` $contiguity attr-dict `:` type($ptr) `->` type($result)
+ }];
+ let builders = [
+ OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$passthrough,
+ CArg<"unsigned", "0">:$alignment,
+ CArg<"ArrayRef<int32_t>", "{}">:$contiguity)>
+ ];
+ let hasVerifier = 1;
+ let extraClassDeclaration = [{
+ /// Returns the ptr type of the operation.
+ PtrType getPtrType() {
+ return cast<PtrType>(getPtr().getType().getElementType());
+ }
+
+ /// Returns the rank of the shaped operands and result.
+ unsigned getRank() { return getType().getRank(); }
+
+ /// Returns the shape of the shaped operands and result.
+ ArrayRef<int64_t> getShape() { return getType().getShape(); }
+
+ /// Returns a pair `(c, s)` where `c` is the contiguity and `s` the size
+ /// of the `i`-th dimension.
+ std::pair<int64_t, int64_t> getContiguityInfo(unsigned i) {
+ assert(i < getRank() && "Invalid dimension");
+ return {getContiguity()[i], getShape()[i]};
+ }
+
+ /// Returns true if the `i`-th dimension is contiguous.
+ bool isContiguous(unsigned i) {
+ auto [contiguity, size] = getContiguityInfo(i);
+ return contiguity == size && size > 1;
+ }
+
+ /// Returns true if the read has gather semantics, i.e., there is no
+ /// dimension where all the accesses are contiguous.
+ bool hasGatherSemantics() {
+ return !llvm::any_of(llvm::seq<unsigned>(0, getRank()),
+ [this](unsigned i) { return isContiguous(i); });
+ }
+ }];
+}
+
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
@@ -645,4 +777,120 @@ def Ptr_TypeOffsetOp : Pointer_Op<"type_offset", [Pure]> {
}];
}
+//===----------------------------------------------------------------------===//
+// WriteOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_WriteOp : Pointer_Op<"write", [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TypesMatchWith<"value and mask must be compatible",
+ "value", "mask", [{
+ cast<ShapedType>($_self).clone(IntegerType::get($_self.getContext(), 1))
+ }]>,
+ // Check that the shapes match and both use the same shaped container type.
+ AllShapesMatch<["value", "ptr"]>, AllTypeIDsMatch<["value", "ptr"]>
+ ]> {
+ let summary = "Write operation";
+ let description = [{
+ The `write` operation is a high-level operation that writes to multiple
+ memory locations specified by `ptr`, predicated on `mask`. Elements of
+ `value` that correspond to masked-off lanes are not written to
+ memory.
+
+ The `mask` operand is a shaped type of `i1` elements that must have the same
+ shape as the `value` type.
+
+ The `contiguity` property is an integer array with one element per
+ dimension of `ptr`, where each element describes the memory access
+ contiguity along that dimension. The precise semantics are as follows:
+ Let `c1, c2, ..., cn` be the elements of the contiguity array, and
+ `s1, s2, ..., sn` be the corresponding elements of the `ptr` shape.
+ The following rules and restrictions apply:
+ 1. `ck` must be strictly positive for all k.
+ 2. `ck` must divide `sk` for all k.
+ 3. For arbitrary but valid indices `j1, ..., jn`, the memory ranges
+ given by:
+ - `ptr[j1, ..., jk-1, l * ck : (l + 1) * ck, jk+1, ..., jn]` for
+ `l = 0, 1, ..., sk / ck - 1`
+ are contiguous for all k.
+
+ It is undefined behavior if the pointers in `ptr` do not satisfy the
+ contiguity constraints specified by `contiguity`.
+
+ Depending on the values of `mask` and `contiguity`, the operation can be
+ lowered to one of:
+ 1. A `ptr.store`, if the mask is all ones and there is a dimension where
+ all the accesses are contiguous.
+ 2. A `ptr.masked_store`, if the mask is not all ones and there is a
+ dimension where all the accesses are contiguous.
+ 3. A `ptr.scatter`, if the mask is not all ones and there is no
+ contiguous dimension.
+
+ The `alignment` property describes the alignment (in bytes) of each
+ contiguous memory block being accessed.
+
+ Examples:
+ ```mlir
+ // Write a vector in row-major order
+ ptr.write %value, %ptr, %mask contiguity = [1, 4] :
+ vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+
+ // Write a vector in column-major order with alignment
+ ptr.write %value, %ptr, %mask alignment = 8 contiguity = [4, 1] :
+ vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+
+ // Scatter a vector to memory
+ ptr.write %value, %ptr, %mask alignment = 8 contiguity = [1, 1] :
+ vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+ ```
+ }];
+ let arguments = (ins Ptr_ShapedAnyType:$value,
+ Ptr_ShapedPtrType:$ptr,
+ Ptr_ShapedMaskType:$mask,
+ AlignmentProp:$alignment,
+ ContiguityProp:$contiguity);
+ let assemblyFormat = [{
+ $value `,` $ptr `,` $mask (`alignment` `=` $alignment^)?
+ `contiguity` `=` $contiguity attr-dict `:` type($value) `,` type($ptr)
+ }];
+ let builders = [
+ OpBuilder<(ins "Value":$value, "Value":$ptr, "Value":$mask,
+ CArg<"unsigned", "0">:$alignment,
+ CArg<"ArrayRef<int32_t>", "{}">:$contiguity)>
+ ];
+ let hasVerifier = 1;
+ let extraClassDeclaration = [{
+ /// Returns the ptr type of the operation.
+ PtrType getPtrType() {
+ return cast<PtrType>(getPtr().getType().getElementType());
+ }
+
+ /// Returns the rank of the shaped operands.
+ unsigned getRank() { return getPtr().getType().getRank(); }
+
+ /// Returns the shape of the shaped operands.
+ ArrayRef<int64_t> getShape() { return getPtr().getType().getShape(); }
+
+ /// Returns a pair `(c, s)` where `c` is the contiguity and `s` the size
+ /// of the `i`-th dimension.
+ std::pair<int64_t, int64_t> getContiguityInfo(unsigned i) {
+ assert(i < getRank() && "Invalid dimension");
+ return {getContiguity()[i], getShape()[i]};
+ }
+
+ /// Returns true if the `i`-th dimension is contiguous.
+ bool isContiguous(unsigned i) {
+ auto [contiguity, size] = getContiguityInfo(i);
+ return contiguity == size && size > 1;
+ }
+
+ /// Returns true if the write has scatter semantics, i.e., there is no
+ /// dimension where all the accesses are contiguous.
+ bool hasScatterSemantics() {
+ return !llvm::any_of(llvm::seq<unsigned>(0, getRank()),
+ [this](unsigned i) { return isContiguous(i); });
+ }
+ }];
+}
+
#endif // PTR_OPS
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index 51f25f755a8a6..ecfbd957bbe24 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -57,6 +57,25 @@ verifyAlignment(std::optional<int64_t> alignment,
return success();
}
+/// Verifies that the contiguity array has one element per dimension, and that
+/// every element is positive and divides the corresponding dimension size.
+static LogicalResult
+verifyContiguityProp(ArrayRef<int32_t> contiguity, ArrayRef<int64_t> shape,
+ function_ref<InFlightDiagnostic()> emitError) {
+ if (contiguity.size() != shape.size()) {
+ return emitError() << "expected contiguity array with " << shape.size()
+ << " elements";
+ }
+ if (!llvm::all_of(llvm::zip(contiguity, shape), [](auto cs) {
+ int32_t c = std::get<0>(cs);
+ return c > 0 && std::get<1>(cs) % c == 0;
+ })) {
+ return emitError()
+ << "expected contiguity values to be positive and divide the shape";
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
@@ -264,6 +283,49 @@ void MaskedStoreOp::build(OpBuilder &builder, OperationState &state,
alignment ? std::optional<int64_t>(alignment) : std::nullopt);
}
+//===----------------------------------------------------------------------===//
+// ReadOp
+//===----------------------------------------------------------------------===//
+
+void ReadOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable());
+}
+
+LogicalResult ReadOp::verify() {
+ auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+
+ // Verify that the pointer type's memory space allows loads.
+ MemorySpaceAttrInterface ms =
+ cast<PtrType>(getPtr().getType().getElementType()).getMemorySpace();
+ DataLayout dataLayout = DataLayout::closest(*this);
+ if (!ms.isValidLoad(getResult().getType(), AtomicOrdering::not_atomic,
+ getAlignment(), &dataLayout, emitDiag))
+ return failure();
+
+ // Verify the alignment.
+ if (failed(verifyAlignment(getAlignment(), emitDiag)))
+ return failure();
+
+ // Verify the contiguity array.
+ return verifyContiguityProp(getContiguity(), getShape(), emitDiag);
+}
+
+void ReadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
+ Value mask, Value passthrough, unsigned alignment,
+ ArrayRef<int32_t> contiguity) {
+ if (!contiguity.empty()) {
+ build(builder, state, ptr, mask, passthrough,
+ alignment ? std::optional<int64_t>(alignment) : std::nullopt,
+ contiguity);
+ return;
+ }
+ build(builder, state, ptr, mask, passthrough,
+ alignment ? std::optional<int64_t>(alignment) : std::nullopt,
+ SmallVector<int32_t>(cast<ShapedType>(ptr.getType()).getRank(), 1));
+}
+
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
@@ -470,6 +532,49 @@ llvm::TypeSize TypeOffsetOp::getTypeSize(std::optional<DataLayout> layout) {
return dl.getTypeSize(getElementType());
}
+//===----------------------------------------------------------------------===//
+// WriteOp
+//===----------------------------------------------------------------------===//
+
+void WriteOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Write::get(), &getPtrMutable());
+}
+
+LogicalResult WriteOp::verify() {
+ auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+
+ // Verify that the pointer type's memory space allows stores.
+ MemorySpaceAttrInterface ms =
+ cast<PtrType>(getPtr().getType().getElementType()).getMemorySpace();
+ DataLayout dataLayout = DataLayout::closest(*this);
+ if (!ms.isValidStore(getValue().getType(), AtomicOrdering::not_atomic,
+ getAlignment(), &dataLayout, emitDiag))
+ return failure();
+
+ // Verify the alignment.
+ if (failed(verifyAlignment(getAlignment(), emitDiag)))
+ return failure();
+
+ // Verify the contiguity array.
+ return verifyContiguityProp(getContiguity(), getShape(), emitDiag);
+}
+
+void WriteOp::build(OpBuilder &builder, OperationState &state, Value value,
+ Value ptr, Value mask, unsigned alignment,
+ ArrayRef<int32_t> contiguity) {
+ if (!contiguity.empty()) {
+ build(builder, state, value, ptr, mask,
+ alignment ? std::optional<int64_t>(alignment) : std::nullopt,
+ contiguity);
+ return;
+ }
+ build(builder, state, value, ptr, mask,
+ alignment ? std::optional<int64_t>(alignment) : std::nullopt,
+ SmallVector<int32_t>(cast<ShapedType>(ptr.getType()).getRank(), 1));
+}
+
//===----------------------------------------------------------------------===//
// Pointer API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir
index 83e1c880650c5..54332a5632808 100644
--- a/mlir/test/Dialect/Ptr/invalid.mlir
+++ b/mlir/test/Dialect/Ptr/invalid.mlir
@@ -78,3 +78,51 @@ func.func @ptr_diff_mismatch(%lhs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %rhs:
%res = ptr.ptr_diff %lhs, %rhs : tensor<8x!ptr.ptr<#ptr.generic_space>> -> vector<8xi64>
return %res : vector<8xi64>
}
+
+// -----
+
+func.func @read_contiguity_does_not_divide(%ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>, %passthrough: vector<4x4xf32>) -> vector<4x4xf32> {
+ // expected-error @+1 {{expected contiguity values to be positive and divide the shape}}
+ %0 = ptr.read %ptr, %mask, %passthrough contiguity = [1, 3] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+ return %0 : vector<4x4xf32>
+}
+
+// -----
+
+func.func @read_contiguity_is_not_positive(%ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>, %passthrough: vector<4x4xf32>) -> vector<4x4xf32> {
+ // expected-error @+1 {{expected contiguity values to be positive and divide the shape}}
+ %0 = ptr.read %ptr, %mask, %passthrough contiguity = [1, -1] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+ return %0 : vector<4x4xf32>
+}
+
+// -----
+
+func.func @read_invalid_contiguity_size(%ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>, %passthrough: vector<4x4xf32>) -> vector<4x4xf32> {
+ // expected-error @+1 {{expected contiguity array with 2 elements}}
+ %0 = ptr.read %ptr, %mask, %passthrough contiguity = [1] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+ return %0 : vector<4x4xf32>
+}
+
+// -----
+
+func.func @write_contiguity_does_not_divide(%value: vector<4x4xf32>, %ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>) {
+ // expected-error @+1 {{expected contiguity values to be positive and divide the shape}}
+ ptr.write %value, %ptr, %mask contiguity = [1, 7] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+ return
+}
+
+// -----
+
+func.func @write_contiguity_is_not_positive(%value: vector<4x4xf32>, %ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>) {
+ // expected-error @+1 {{expected contiguity values to be positive and divide the shape}}
+ ptr.write %value, %ptr, %mask contiguity = [0, 4] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+ return
+}
+
+// -----
+
+func.func @write_invalid_contiguity_size(%value: vector<4x4xf32>, %ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>) {
+ // expected-error @+1 {{expected contiguity array with 2 elements}}
+ ptr.write %value, %ptr, %mask contiguity = [1] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+ return
+}
diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir
index 0a906ad559e21..d0c0390d6932e 100644
--- a/mlir/test/Dialect/Ptr/ops.mlir
+++ b/mlir/test/Dialect/Ptr/ops.mlir
@@ -239,3 +239,39 @@ func.func @ptr_diff_tensor_2d_ops(%ptrs1: tensor<4x8x!ptr.ptr<#ptr.generic_space
%diff = ptr.ptr_diff %ptrs1, %ptrs2 : tensor<4x8x!ptr.ptr<#ptr.generic_space>> -> tensor<4x8xi64>
return %diff : tensor<4x8xi64>
}
+
+/// Check read op assembly.
+func.func @read_ops(%ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>, %passthrough: vector<4x4xf32>) {
+ // Row-major styled read
+ %0 = ptr.read %ptr, %mask, %passthrough contiguity = [1, 4] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+ // Column-major styled read
+ %1 = ptr.read %ptr, %mask, %passthrough alignment = 8 contiguity = [4, 1] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+ // Gather styled read
+ %2 = ptr.read %ptr, %mask, %passthrough alignment = 8 contiguity = [1, 1] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+ return
+}
+
+/// Check read op assembly with tensors
+func.func @read_ops_tensor(%ptr: tensor<8x!ptr.ptr<#ptr.generic_space>>, %mask: tensor<8xi1>, %passthrough: tensor<8xf32>) {
+ %0 = ptr.read %ptr, %mask, %passthrough contiguity = [1] : tensor<8x!ptr.ptr<#ptr.generic_space>> -> tensor<8xf32>
+ %1 = ptr.read %ptr, %mask, %passthrough alignment = 4 contiguity = [8] : tensor<8x!ptr.ptr<#ptr.generic_space>> -> tensor<8xf32>
+ return
+}
+
+/// Check write op assembly.
+func.func @write_ops(%value: vector<4x4xf32>, %ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>) {
+ // Row-major styled write
+ ptr.write %value, %ptr, %mask contiguity = [1, 4] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+ // Column-major styled write
+ ptr.write %value, %ptr, %mask alignment = 8 contiguity = [4, 1] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+ // Scatter styled write
+ ptr.write %value, %ptr, %mask alignment = 8 contiguity = [1, 1] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+ return
+}
+
+/// Check write op assembly with tensors
+func.func @write_ops_tensor(%value: tensor<8xf32>, %ptr: tensor<8x!ptr.ptr<#ptr.generic_space>>, %mask: tensor<8xi1>) {
+ ptr.write %value, %ptr, %mask contiguity = [1] : tensor<8xf32>, tensor<8x!ptr.ptr<#ptr.generic_space>>
+ ptr.write %value, %ptr, %mask alignment = 4 contiguity = [8] : tensor<8xf32>, tensor<8x!ptr.ptr<#ptr.generic_space>>
+ return
+}