[Mlir-commits] [mlir] [mlir][ptr] Add `gather`, `masked_load`, `masked_store`, and `scatter` ops (PR #156368)
Fabian Mora
llvmlistbot at llvm.org
Wed Sep 3 07:30:58 PDT 2025
https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/156368
>From 6d1721146c9996fdfd0928d53f819764a8e22b1c Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Sun, 31 Aug 2025 12:01:19 +0000
Subject: [PATCH 1/3] Add load, store variant ops
---
mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 238 +++++++++++++++++-
mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 156 +++++++++++-
.../Dialect/Ptr/PtrToLLVMIRTranslation.cpp | 119 +++++++++
mlir/test/Dialect/Ptr/ops.mlir | 70 ++++++
mlir/test/Target/LLVMIR/ptr.mlir | 114 +++++++++
5 files changed, 682 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 1c88efced950e..170513d57c7be 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -17,6 +17,46 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/OpAsmInterface.td"
+//===----------------------------------------------------------------------===//
+// Common props
+//===----------------------------------------------------------------------===//
+
+def AlignmentProp : OptionalProp<I64Prop>;
+
+//===----------------------------------------------------------------------===//
+// Common types
+//===----------------------------------------------------------------------===//
+
+// A shaped value type with value semantics and rank.
+class Ptr_ShapedValueType<list<Type> allowedTypes, list<Pred> preds = []> :
+ ShapedContainerType<allowedTypes,
+ /*containerPred=*/And<[HasValueSemanticsPred] # preds>,
+ /*descr=*/[{A shaped type with value semantics and rank.}],
+ /*cppType=*/"::mlir::ShapedType">;
+
+// A shaped pointer type with value semantics and rank.
+class Ptr_ShapedPtrType : Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>;
+
+// A shaped value type of rank 1 of any element type.
+def Ptr_Any1DType :
+ Ptr_ShapedValueType<[AnyType], [HasAnyRankOfPred<[1]>]>;
+
+// A shaped value type of rank 1 of `i1` element type.
+def Ptr_Mask1DType :
+ Ptr_ShapedValueType<[I1], [HasAnyRankOfPred<[1]>]>;
+
+// A shaped value type of rank 1 of `!ptr.ptr` element type.
+def Ptr_Ptr1DType :
+ Ptr_ShapedValueType<[Ptr_PtrType], [HasAnyRankOfPred<[1]>]>;
+
+// Builds the C++ expression for the TypeID of the type of value `name`.
+class TypeIDType<string name> :
+ StrFunc<"$" # name # ".getType().getTypeID()">;
+
+// Checks that the types of all named values share the same TypeID.
+class AllTypeIDsMatch<list<string> names> :
+ AllMatchSameOperatorTrait<names, TypeIDType<"_self">.result, "type IDs">;
+
//===----------------------------------------------------------------------===//
// FromPtrOp
//===----------------------------------------------------------------------===//
@@ -56,6 +96,58 @@ def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// GatherOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_GatherOp : Pointer_Op<"gather", [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TypesMatchWith<"result and mask must be compatible", "result", "mask", [{
+ ::llvm::cast<ShapedType>($_self).clone(
+ IntegerType::get($_self.getContext(), 1))
+ }]>,
+ AllTypesMatch<["result", "passthrough"]>,
+ // Check the shapes are compatible and both use the same shaped container
+ // type.
+ AllShapesMatch<["result", "ptrs"]>, AllTypeIDsMatch<["result", "ptrs"]>
+ ]> {
+ let summary = "Gather operation";
+ let description = [{
+ The `gather` operation conditionally loads one element per lane through
+ the matching pointer in `ptrs`, for the lanes enabled by `mask`. Lanes of
+ the result that are masked off take their value from the `passthrough`
+ operand.
+
+ `mask` must be the result type with `i1` elements; `ptrs` must have the
+ same shape and container kind (vector or tensor) as the result.
+
+ Examples:
+ ```mlir
+ // Gather values from multiple memory locations
+ %result = ptr.gather %ptrs, %mask, %passthrough :
+ vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xf32>
+
+ // Gather with alignment
+ %result = ptr.gather %ptrs, %mask, %passthrough alignment = 8 :
+ vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xf32>
+ ```
+ }];
+ let arguments = (ins Ptr_Ptr1DType:$ptrs,
+ Ptr_Mask1DType:$mask,
+ Ptr_Any1DType:$passthrough,
+ AlignmentProp:$alignment);
+ let results = (outs Ptr_Any1DType:$result);
+ let assemblyFormat = [{
+ $ptrs `,` $mask `,` $passthrough (`alignment` `=` $alignment^)?
+ attr-dict `:` qualified(type($ptrs)) `->` type($result)
+ }];
+ let builders = [
+ OpBuilder<(ins "Type":$resultType, "Value":$ptrs, "Value":$mask,
+ "Value":$passthrough, CArg<"unsigned", "0">:$alignment)>
+ ];
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// GetMetadataOp
//===----------------------------------------------------------------------===//
@@ -122,8 +214,6 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
// LoadOp
//===----------------------------------------------------------------------===//
-def AlignmentProp : OptionalProp<I64Prop>;
-
def Ptr_LoadOp : Pointer_Op<"load", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
@@ -184,6 +274,150 @@ def Ptr_LoadOp : Pointer_Op<"load", [
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// MaskedLoadOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_MaskedLoadOp : Pointer_Op<"masked_load", [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TypesMatchWith<"result and mask must be compatible", "result", "mask", [{
+ ::llvm::cast<ShapedType>($_self).clone(
+ IntegerType::get($_self.getContext(), 1))
+ }]>,
+ AllTypesMatch<["result", "passthrough"]>
+ ]> {
+ let summary = "Masked load operation";
+ let description = [{
+ The `masked_load` operation conditionally loads consecutive elements from
+ memory starting at `ptr`, one per lane enabled by `mask`. Lanes of the
+ result that are masked off take their value from `passthrough`.
+
+ `mask` must be the result type with `i1` elements; `passthrough` must have
+ the same type as the result.
+
+ Examples:
+ ```mlir
+ // Masked load with passthrough on vectors
+ %result = ptr.masked_load %ptr, %mask, %passthrough :
+ !ptr.ptr<#ptr.generic_space> -> vector<4xf32>
+
+ // Masked load with passthrough and alignment on tensors
+ %result = ptr.masked_load %ptr, %mask, %passthrough alignment = 4 :
+ !ptr.ptr<#ptr.generic_space> -> tensor<4xf32>
+ ```
+ }];
+ let arguments = (ins Ptr_PtrType:$ptr,
+ Ptr_Mask1DType:$mask,
+ Ptr_Any1DType:$passthrough,
+ AlignmentProp:$alignment);
+ let results = (outs Ptr_Any1DType:$result);
+ let assemblyFormat = [{
+ $ptr `,` $mask `,` $passthrough (`alignment` `=` $alignment^)?
+ attr-dict `:` qualified(type($ptr)) `->` type($result)
+ }];
+ let builders = [
+ OpBuilder<(ins "Type":$resultType, "Value":$ptr, "Value":$mask,
+ "Value":$passthrough, CArg<"unsigned", "0">:$alignment)>
+ ];
+ let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// MaskedStoreOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_MaskedStoreOp : Pointer_Op<"masked_store", [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TypesMatchWith<"value and mask must be compatible", "value", "mask", [{
+ ::llvm::cast<ShapedType>($_self).clone(
+ IntegerType::get($_self.getContext(), 1))
+ }]>
+ ]> {
+ let summary = "Masked store operation";
+ let description = [{
+ The `masked_store` operation conditionally stores the lanes of `value` to
+ consecutive memory locations starting at `ptr`. Only the lanes enabled by
+ `mask` are written; memory for the other lanes is not written.
+
+ `mask` must be the type of `value` with `i1` elements, i.e. it has the
+ same shape as the value being stored.
+
+ Examples:
+ ```mlir
+ // Masked store
+ ptr.masked_store %value, %ptr, %mask :
+ vector<4xf32>, !ptr.ptr<#ptr.generic_space>
+
+ // Masked store with alignment
+ ptr.masked_store %value, %ptr, %mask alignment = 8 :
+ vector<4xf32>, !ptr.ptr<#ptr.generic_space>
+ ```
+ }];
+
+ let arguments = (ins Ptr_Any1DType:$value,
+ Ptr_PtrType:$ptr,
+ Ptr_Mask1DType:$mask,
+ AlignmentProp:$alignment);
+ let assemblyFormat = [{
+ $value `,` $ptr `,` $mask (`alignment` `=` $alignment^)? attr-dict `:`
+ type($value) `,` qualified(type($ptr))
+ }];
+ let builders = [
+ OpBuilder<(ins "Value":$value, "Value":$ptr, "Value":$mask,
+ CArg<"unsigned", "0">:$alignment)>
+ ];
+ let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// ScatterOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_ScatterOp : Pointer_Op<"scatter", [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TypesMatchWith<"value and mask must be compatible", "value", "mask", [{
+ ::llvm::cast<ShapedType>($_self).clone(
+ IntegerType::get($_self.getContext(), 1))
+ }]>,
+ // Check the shapes are compatible and both use the same shaped container
+ // type.
+ AllShapesMatch<["value", "ptrs"]>, AllTypeIDsMatch<["value", "ptrs"]>
+ ]> {
+ let summary = "Scatter operation";
+ let description = [{
+ The `scatter` operation stores each lane of `value` through the matching
+ pointer in `ptrs`, but only for the lanes enabled by `mask`.
+
+ Masked-off lanes do not write to memory. `mask` must be the type of
+ `value` with `i1` elements, and `ptrs` must have the same shape and
+ container kind (e.g. both vectors or both tensors) as `value`.
+
+ Examples:
+ ```mlir
+ // Scatter values to multiple memory locations
+ ptr.scatter %value, %ptrs, %mask :
+ vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
+
+ // Scatter with alignment
+ ptr.scatter %value, %ptrs, %mask alignment = 8 :
+ vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
+ ```
+ }];
+ let arguments = (ins Ptr_Any1DType:$value,
+ Ptr_Ptr1DType:$ptrs,
+ Ptr_Mask1DType:$mask,
+ AlignmentProp:$alignment);
+ let assemblyFormat = [{
+ $value `,` $ptrs `,` $mask (`alignment` `=` $alignment^)?
+ attr-dict `:` type($value) `,` qualified(type($ptrs))
+ }];
+ let builders = [
+ OpBuilder<(ins "Value":$value, "Value":$ptrs, "Value":$mask,
+ CArg<"unsigned", "0">:$alignment)>
+ ];
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index d5976b9a41ff6..74e35f48b033c 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -39,6 +39,23 @@ void PtrDialect::initialize() {
>();
}
+//===----------------------------------------------------------------------===//
+// Common helper functions.
+//===----------------------------------------------------------------------===//
+
+/// Verifies that `alignment`, if present, is a positive power of 2.
+static LogicalResult
+verifyAlignment(std::optional<int64_t> alignment,
+ function_ref<InFlightDiagnostic()> emitError) {
+ if (!alignment) // An unset alignment is always valid.
+ return success();
+ if (alignment.value() <= 0)
+ return emitError() << "alignment must be positive";
+ if (!llvm::isPowerOf2_64(alignment.value()))
+ return emitError() << "alignment must be a power of 2";
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// FromPtrOp
//===----------------------------------------------------------------------===//
@@ -84,6 +101,39 @@ LogicalResult FromPtrOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// GatherOp
+//===----------------------------------------------------------------------===//
+
+void GatherOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ // Reads through the pointers in `ptrs`; the effect is tied to that operand.
+ effects.emplace_back(MemoryEffects::Read::get(), &getPtrsMutable());
+}
+
+LogicalResult GatherOp::verify() {
+ auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+
+ // Check that the element pointer's memory space allows a non-atomic load.
+ MemorySpaceAttrInterface ms =
+ cast<PtrType>(getPtrs().getType().getElementType()).getMemorySpace();
+ DataLayout dataLayout = DataLayout::closest(*this);
+ if (!ms.isValidLoad(getResult().getType(), AtomicOrdering::not_atomic,
+ getAlignment(), dataLayout, emitDiag))
+ return failure();
+
+ // The alignment, if set, must be a positive power of 2.
+ return verifyAlignment(getAlignment(), emitDiag);
+}
+
+void GatherOp::build(OpBuilder &builder, OperationState &state, Type resultType,
+ Value ptrs, Value mask, Value passthrough,
+ unsigned alignment) { // An alignment of 0 leaves the property unset.
+ build(builder, state, resultType, ptrs, mask, passthrough,
+ alignment ? std::optional<int64_t>(alignment) : std::nullopt);
+}
+
//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//
@@ -107,19 +157,6 @@ verifyAtomicMemOp(OpTy memOp, ArrayRef<AtomicOrdering> unsupportedOrderings) {
return success();
}
-/// Verifies that the alignment attribute is a power of 2 if present.
-static LogicalResult
-verifyAlignment(std::optional<int64_t> alignment,
- function_ref<InFlightDiagnostic()> emitError) {
- if (!alignment)
- return success();
- if (alignment.value() <= 0)
- return emitError() << "alignment must be positive";
- if (!llvm::isPowerOf2_64(alignment.value()))
- return emitError() << "alignment must be a power of 2";
- return success();
-}
-
void LoadOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
@@ -158,6 +195,99 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
isVolatile, isNonTemporal, isInvariant, isInvariantGroup, ordering,
syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
}
+//===----------------------------------------------------------------------===//
+// MaskedLoadOp
+//===----------------------------------------------------------------------===//
+
+void MaskedLoadOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ // Reads through `ptr`; the effect is tied to that operand.
+ effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable());
+}
+
+LogicalResult MaskedLoadOp::verify() {
+ auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+ // Check that `ptr`'s memory space allows a non-atomic load of the result.
+ MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
+ DataLayout dataLayout = DataLayout::closest(*this);
+ if (!ms.isValidLoad(getResult().getType(), AtomicOrdering::not_atomic,
+ getAlignment(), dataLayout, emitDiag))
+ return failure();
+
+ // The alignment, if set, must be a positive power of 2.
+ return verifyAlignment(getAlignment(), emitDiag);
+}
+
+void MaskedLoadOp::build(OpBuilder &builder, OperationState &state,
+ Type resultType, Value ptr, Value mask,
+ Value passthrough, unsigned alignment) { // 0 leaves alignment unset.
+ build(builder, state, resultType, ptr, mask, passthrough,
+ alignment ? std::optional<int64_t>(alignment) : std::nullopt);
+}
+
+//===----------------------------------------------------------------------===//
+// MaskedStoreOp
+//===----------------------------------------------------------------------===//
+
+void MaskedStoreOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ // Writes through `ptr`; the effect is tied to that operand.
+ effects.emplace_back(MemoryEffects::Write::get(), &getPtrMutable());
+}
+
+LogicalResult MaskedStoreOp::verify() {
+ auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+ // Check that `ptr`'s memory space allows a non-atomic store of `value`.
+ MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
+ DataLayout dataLayout = DataLayout::closest(*this);
+ if (!ms.isValidStore(getValue().getType(), AtomicOrdering::not_atomic,
+ getAlignment(), dataLayout, emitDiag))
+ return failure();
+
+ // The alignment, if set, must be a positive power of 2.
+ return verifyAlignment(getAlignment(), emitDiag);
+}
+
+void MaskedStoreOp::build(OpBuilder &builder, OperationState &state,
+ Value value, Value ptr, Value mask,
+ unsigned alignment) { // An alignment of 0 leaves the property unset.
+ build(builder, state, value, ptr, mask,
+ alignment ? std::optional<int64_t>(alignment) : std::nullopt);
+}
+
+//===----------------------------------------------------------------------===//
+// ScatterOp
+//===----------------------------------------------------------------------===//
+
+void ScatterOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ // Writes through the pointers in `ptrs`; the effect is tied to that operand.
+ effects.emplace_back(MemoryEffects::Write::get(), &getPtrsMutable());
+}
+
+LogicalResult ScatterOp::verify() {
+ auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+
+ // Check that the element pointer's memory space allows a non-atomic store.
+ MemorySpaceAttrInterface ms =
+ cast<PtrType>(getPtrs().getType().getElementType()).getMemorySpace();
+ DataLayout dataLayout = DataLayout::closest(*this);
+ if (!ms.isValidStore(getValue().getType(), AtomicOrdering::not_atomic,
+ getAlignment(), dataLayout, emitDiag))
+ return failure();
+
+ // The alignment, if set, must be a positive power of 2.
+ return verifyAlignment(getAlignment(), emitDiag);
+}
+
+void ScatterOp::build(OpBuilder &builder, OperationState &state, Value value,
+ Value ptrs, Value mask, unsigned alignment) { // 0 leaves alignment unset.
+ build(builder, state, value, ptrs, mask,
+ alignment ? std::optional<int64_t>(alignment) : std::nullopt);
+}
//===----------------------------------------------------------------------===//
// StoreOp
diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
index 6bcb293a7d821..3bcc16927b352 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
@@ -204,6 +204,112 @@ convertTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder,
return success();
}
+/// Translates `ptr.gather` into a call to the `llvm.masked.gather` intrinsic.
+static LogicalResult
+convertGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::Value *ptrs = moduleTranslation.lookupValue(gatherOp.getPtrs());
+ llvm::Value *mask = moduleTranslation.lookupValue(gatherOp.getMask());
+ llvm::Value *passthrough =
+ moduleTranslation.lookupValue(gatherOp.getPassthrough());
+
+ if (!ptrs || !mask || !passthrough)
+ return gatherOp.emitError("Failed to lookup operands");
+
+ // Convert result type to LLVM type.
+ llvm::Type *resultType =
+ moduleTranslation.convertType(gatherOp.getResult().getType());
+ if (!resultType)
+ return gatherOp.emitError("Failed to convert result type");
+
+ // An unset alignment (0 -> empty MaybeAlign) is emitted as alignment 1.
+ llvm::MaybeAlign alignment(gatherOp.getAlignment().value_or(0));
+
+ // Emit the intrinsic call; its result becomes the translation of the op.
+ llvm::Value *result = builder.CreateMaskedGather(
+ resultType, ptrs, alignment.valueOrOne(), mask, passthrough);
+
+ moduleTranslation.mapValue(gatherOp.getResult(), result);
+ return success();
+}
+
+/// Translates `ptr.masked_load` into a call to the `llvm.masked.load` intrinsic.
+static LogicalResult
+convertMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::Value *ptr = moduleTranslation.lookupValue(maskedLoadOp.getPtr());
+ llvm::Value *mask = moduleTranslation.lookupValue(maskedLoadOp.getMask());
+ llvm::Value *passthrough =
+ moduleTranslation.lookupValue(maskedLoadOp.getPassthrough());
+
+ if (!ptr || !mask || !passthrough)
+ return maskedLoadOp.emitError("Failed to lookup operands");
+
+ // Convert result type to LLVM type.
+ llvm::Type *resultType =
+ moduleTranslation.convertType(maskedLoadOp.getResult().getType());
+ if (!resultType)
+ return maskedLoadOp.emitError("Failed to convert result type");
+
+ // An unset alignment (0 -> empty MaybeAlign) is emitted as alignment 1.
+ llvm::MaybeAlign alignment(maskedLoadOp.getAlignment().value_or(0));
+
+ // Emit the intrinsic call; its result becomes the translation of the op.
+ llvm::Value *result = builder.CreateMaskedLoad(
+ resultType, ptr, alignment.valueOrOne(), mask, passthrough);
+
+ moduleTranslation.mapValue(maskedLoadOp.getResult(), result);
+ return success();
+}
+
+/// Translates `ptr.masked_store` into a call to the `llvm.masked.store` intrinsic.
+static LogicalResult
+convertMaskedStoreOp(MaskedStoreOp maskedStoreOp, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::Value *value = moduleTranslation.lookupValue(maskedStoreOp.getValue());
+ llvm::Value *ptr = moduleTranslation.lookupValue(maskedStoreOp.getPtr());
+ llvm::Value *mask = moduleTranslation.lookupValue(maskedStoreOp.getMask());
+
+ if (!value || !ptr || !mask)
+ return maskedStoreOp.emitError("Failed to lookup operands");
+
+ // `llvm.masked.store` requires vector data; `getType()` is never null.
+ llvm::Type *valueType = value->getType();
+ if (!valueType->isVectorTy())
+ return maskedStoreOp.emitError("Expected a vector value to store");
+
+ // An unset alignment (0 -> empty MaybeAlign) is emitted as alignment 1.
+ llvm::MaybeAlign alignment(maskedStoreOp.getAlignment().value_or(0));
+
+ // Create the masked store intrinsic call.
+ builder.CreateMaskedStore(value, ptr, alignment.valueOrOne(), mask);
+ return success();
+}
+
+/// Translates `ptr.scatter` into a call to the `llvm.masked.scatter` intrinsic.
+static LogicalResult
+convertScatterOp(ScatterOp scatterOp, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::Value *value = moduleTranslation.lookupValue(scatterOp.getValue());
+ llvm::Value *ptrs = moduleTranslation.lookupValue(scatterOp.getPtrs());
+ llvm::Value *mask = moduleTranslation.lookupValue(scatterOp.getMask());
+
+ if (!value || !ptrs || !mask)
+ return scatterOp.emitError("Failed to lookup operands");
+
+ // `llvm.masked.scatter` requires vector data; `getType()` is never null.
+ llvm::Type *valueType = value->getType();
+ if (!valueType->isVectorTy())
+ return scatterOp.emitError("Expected a vector value to scatter");
+
+ // An unset alignment (0 -> empty MaybeAlign) is emitted as alignment 1.
+ llvm::MaybeAlign alignment(scatterOp.getAlignment().value_or(0));
+
+ // Create the masked scatter intrinsic call.
+ builder.CreateMaskedScatter(value, ptrs, alignment.valueOrOne(), mask);
+ return success();
+}
+
/// Implementation of the dialect interface that converts operations belonging
/// to the `ptr` dialect to LLVM IR.
class PtrDialectLLVMIRTranslationInterface
@@ -230,6 +336,19 @@ class PtrDialectLLVMIRTranslationInterface
.Case([&](TypeOffsetOp typeOffsetOp) {
return convertTypeOffsetOp(typeOffsetOp, builder, moduleTranslation);
})
+ .Case<GatherOp>([&](GatherOp gatherOp) {
+ return convertGatherOp(gatherOp, builder, moduleTranslation);
+ })
+ .Case<MaskedLoadOp>([&](MaskedLoadOp maskedLoadOp) {
+ return convertMaskedLoadOp(maskedLoadOp, builder, moduleTranslation);
+ })
+ .Case<MaskedStoreOp>([&](MaskedStoreOp maskedStoreOp) {
+ return convertMaskedStoreOp(maskedStoreOp, builder,
+ moduleTranslation);
+ })
+ .Case<ScatterOp>([&](ScatterOp scatterOp) {
+ return convertScatterOp(scatterOp, builder, moduleTranslation);
+ })
.Default([&](Operation *op) {
return op->emitError("Translation for operation '")
<< op->getName() << "' is not implemented.";
diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir
index 3f3ad05c46acc..bde2fb22b6aa0 100644
--- a/mlir/test/Dialect/Ptr/ops.mlir
+++ b/mlir/test/Dialect/Ptr/ops.mlir
@@ -56,3 +56,73 @@ func.func @llvm_store(%arg0: !ptr.ptr<#llvm.address_space<2>>, %arg1: f32, %arg2
ptr.store %arg2, %arg0 atomic release alignment = 8 : i64, !ptr.ptr<#llvm.address_space<2>>
return
}
+
+/// Test `ptr.gather` on vectors, with and without alignment.
+func.func @gather_ops(%ptrs: vector<4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4xi1>, %passthrough: vector<4xf32>) -> vector<4xf32> {
+ %0 = ptr.gather %ptrs, %mask, %passthrough : vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xf32>
+ %1 = ptr.gather %ptrs, %mask, %passthrough alignment = 8 : vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+/// Test `ptr.gather` on tensors, with and without alignment.
+func.func @gather_ops_tensor(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %mask: tensor<8xi1>, %passthrough: tensor<8xi32>) -> tensor<8xi32> {
+ %0 = ptr.gather %ptrs, %mask, %passthrough : tensor<8x!ptr.ptr<#ptr.generic_space>> -> tensor<8xi32>
+ %1 = ptr.gather %ptrs, %mask, %passthrough alignment = 4 : tensor<8x!ptr.ptr<#ptr.generic_space>> -> tensor<8xi32>
+ return %0 : tensor<8xi32>
+}
+
+/// Test `ptr.scatter` on vectors, with and without alignment.
+func.func @scatter_ops(%value: vector<4xf32>, %ptrs: vector<4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4xi1>) {
+ ptr.scatter %value, %ptrs, %mask : vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
+ ptr.scatter %value, %ptrs, %mask alignment = 16 : vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
+ return
+}
+
+/// Test `ptr.scatter` on tensors, with and without alignment.
+func.func @scatter_ops_tensor(%value: tensor<8xi64>, %ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %mask: tensor<8xi1>) {
+ ptr.scatter %value, %ptrs, %mask : tensor<8xi64>, tensor<8x!ptr.ptr<#ptr.generic_space>>
+ ptr.scatter %value, %ptrs, %mask alignment = 8 : tensor<8xi64>, tensor<8x!ptr.ptr<#ptr.generic_space>>
+ return
+}
+
+/// Test `ptr.masked_load` on vectors, with and without alignment.
+func.func @masked_load_ops(%ptr: !ptr.ptr<#ptr.generic_space>, %mask: vector<4xi1>, %passthrough: vector<4xf32>) -> vector<4xf32> {
+ %0 = ptr.masked_load %ptr, %mask, %passthrough : !ptr.ptr<#ptr.generic_space> -> vector<4xf32>
+ %1 = ptr.masked_load %ptr, %mask, %passthrough alignment = 16 : !ptr.ptr<#ptr.generic_space> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+/// Test `ptr.masked_load` on tensors, with and without alignment.
+func.func @masked_load_ops_tensor(%ptr: !ptr.ptr<#ptr.generic_space>, %mask: tensor<8xi1>, %passthrough: tensor<8xi32>) -> tensor<8xi32> {
+ %0 = ptr.masked_load %ptr, %mask, %passthrough : !ptr.ptr<#ptr.generic_space> -> tensor<8xi32>
+ %1 = ptr.masked_load %ptr, %mask, %passthrough alignment = 4 : !ptr.ptr<#ptr.generic_space> -> tensor<8xi32>
+ return %0 : tensor<8xi32>
+}
+
+/// Test `ptr.masked_store` on vectors, with and without alignment.
+func.func @masked_store_ops(%value: vector<4xf32>, %ptr: !ptr.ptr<#ptr.generic_space>, %mask: vector<4xi1>) {
+ ptr.masked_store %value, %ptr, %mask : vector<4xf32>, !ptr.ptr<#ptr.generic_space>
+ ptr.masked_store %value, %ptr, %mask alignment = 32 : vector<4xf32>, !ptr.ptr<#ptr.generic_space>
+ return
+}
+
+/// Test `ptr.masked_store` on tensors, with and without alignment.
+func.func @masked_store_ops_tensor(%value: tensor<8xi64>, %ptr: !ptr.ptr<#ptr.generic_space>, %mask: tensor<8xi1>) {
+ ptr.masked_store %value, %ptr, %mask : tensor<8xi64>, !ptr.ptr<#ptr.generic_space>
+ ptr.masked_store %value, %ptr, %mask alignment = 8 : tensor<8xi64>, !ptr.ptr<#ptr.generic_space>
+ return
+}
+
+/// Test all four masked ops on pointers in `#llvm.address_space<3>`.
+func.func @llvm_masked_ops(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>,
+ %mask: vector<4xi1>, %value: vector<4xf32>, %passthrough: vector<4xf32>) -> vector<4xf32> {
+ // Gather through address-space-3 pointers
+ %0 = ptr.gather %ptrs, %mask, %passthrough alignment = 4 : vector<4x!ptr.ptr<#llvm.address_space<3>>> -> vector<4xf32>
+ // Scatter through address-space-3 pointers
+ ptr.scatter %value, %ptrs, %mask alignment = 4 : vector<4xf32>, vector<4x!ptr.ptr<#llvm.address_space<3>>>
+ // Masked load through an address-space-3 pointer
+ %1 = ptr.masked_load %ptr, %mask, %passthrough alignment = 4 : !ptr.ptr<#llvm.address_space<3>> -> vector<4xf32>
+ // Masked store through an address-space-3 pointer
+ ptr.masked_store %value, %ptr, %mask alignment = 4 : vector<4xf32>, !ptr.ptr<#llvm.address_space<3>>
+ return %0 : vector<4xf32>
+}
diff --git a/mlir/test/Target/LLVMIR/ptr.mlir b/mlir/test/Target/LLVMIR/ptr.mlir
index 6e3b365b862e2..545bec5979b2d 100644
--- a/mlir/test/Target/LLVMIR/ptr.mlir
+++ b/mlir/test/Target/LLVMIR/ptr.mlir
@@ -89,3 +89,117 @@ llvm.func @store_ops(%arg0: !ptr.ptr<#llvm.address_space<0>>, %arg1: f32, %arg2:
ptr.store volatile %arg3, %arg0 atomic syncscope("workgroup") release nontemporal alignment = 4 : i32, !ptr.ptr<#llvm.address_space<0>>
llvm.return
}
+
+// CHECK-LABEL: define <4 x float> @gather_ops
+// CHECK-SAME: (<4 x ptr> %[[PTRS:.*]], <4 x i1> %[[MASK:.*]], <4 x float> %[[PASSTHROUGH:.*]]) {
+// CHECK-NEXT: %[[V0:.*]] = call <4 x float> @llvm.masked.gather.v4f32.v4p0(<4 x ptr> %[[PTRS]], i32 1, <4 x i1> %[[MASK]], <4 x float> %[[PASSTHROUGH]])
+// CHECK-NEXT: %[[V1:.*]] = call <4 x float> @llvm.masked.gather.v4f32.v4p0(<4 x ptr> %[[PTRS]], i32 4, <4 x i1> %[[MASK]], <4 x float> %[[PASSTHROUGH]])
+// CHECK-NEXT: ret <4 x float> %[[V0]]
+// CHECK-NEXT: }
+llvm.func @gather_ops(%ptrs: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %mask: vector<4xi1>, %passthrough: vector<4xf32>) -> vector<4xf32> {
+ // Gather without alignment: expected to lower with an alignment argument of i32 1
+ %0 = ptr.gather %ptrs, %mask, %passthrough : vector<4x!ptr.ptr<#llvm.address_space<0>>> -> vector<4xf32>
+ // Gather with alignment = 4: expected to lower with an alignment argument of i32 4
+ %1 = ptr.gather %ptrs, %mask, %passthrough alignment = 4 : vector<4x!ptr.ptr<#llvm.address_space<0>>> -> vector<4xf32>
+ llvm.return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x i32> @gather_ops_i32
+// CHECK-SAME: (<8 x ptr> %[[PTRS:.*]], <8 x i1> %[[MASK:.*]], <8 x i32> %[[PASSTHROUGH:.*]]) {
+// CHECK-NEXT: %[[V0:.*]] = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> %[[PTRS]], i32 8, <8 x i1> %[[MASK]], <8 x i32> %[[PASSTHROUGH]])
+// CHECK-NEXT: ret <8 x i32> %[[V0]]
+// CHECK-NEXT: }
+llvm.func @gather_ops_i32(%ptrs: vector<8x!ptr.ptr<#llvm.address_space<0>>>, %mask: vector<8xi1>, %passthrough: vector<8xi32>) -> vector<8xi32> {
+ %0 = ptr.gather %ptrs, %mask, %passthrough alignment = 8 : vector<8x!ptr.ptr<#llvm.address_space<0>>> -> vector<8xi32> // i32 elements, 8-wide; alignment = 8 becomes the i32 8 intrinsic argument
+ llvm.return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: define <4 x float> @masked_load_ops
+// CHECK-SAME: (ptr %[[PTR:.*]], <4 x i1> %[[MASK:.*]], <4 x float> %[[PASSTHROUGH:.*]]) {
+// CHECK-NEXT: %[[V0:.*]] = call <4 x float> @llvm.masked.load.v4f32.p0(ptr %[[PTR]], i32 1, <4 x i1> %[[MASK]], <4 x float> %[[PASSTHROUGH]])
+// CHECK-NEXT: %[[V1:.*]] = call <4 x float> @llvm.masked.load.v4f32.p0(ptr %[[PTR]], i32 16, <4 x i1> %[[MASK]], <4 x float> %[[PASSTHROUGH]])
+// CHECK-NEXT: ret <4 x float> %[[V0]]
+// CHECK-NEXT: }
+llvm.func @masked_load_ops(%ptr: !ptr.ptr<#llvm.address_space<0>>, %mask: vector<4xi1>, %passthrough: vector<4xf32>) -> vector<4xf32> {
+ // Masked load without alignment: expected to lower with an alignment argument of i32 1
+ %0 = ptr.masked_load %ptr, %mask, %passthrough : !ptr.ptr<#llvm.address_space<0>> -> vector<4xf32>
+ // Masked load with alignment = 16: expected to lower with an alignment argument of i32 16
+ %1 = ptr.masked_load %ptr, %mask, %passthrough alignment = 16 : !ptr.ptr<#llvm.address_space<0>> -> vector<4xf32>
+ llvm.return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x i64> @masked_load_ops_i64
+// CHECK-SAME: (ptr %[[PTR:.*]], <8 x i1> %[[MASK:.*]], <8 x i64> %[[PASSTHROUGH:.*]]) {
+// CHECK-NEXT: %[[V0:.*]] = call <8 x i64> @llvm.masked.load.v8i64.p0(ptr %[[PTR]], i32 8, <8 x i1> %[[MASK]], <8 x i64> %[[PASSTHROUGH]])
+// CHECK-NEXT: ret <8 x i64> %[[V0]]
+// CHECK-NEXT: }
+llvm.func @masked_load_ops_i64(%ptr: !ptr.ptr<#llvm.address_space<0>>, %mask: vector<8xi1>, %passthrough: vector<8xi64>) -> vector<8xi64> {
+ %0 = ptr.masked_load %ptr, %mask, %passthrough alignment = 8 : !ptr.ptr<#llvm.address_space<0>> -> vector<8xi64> // i64 elements; alignment = 8 becomes the i32 8 intrinsic argument
+ llvm.return %0 : vector<8xi64>
+}
+
+// CHECK-LABEL: define void @masked_store_ops
+// CHECK-SAME: (ptr %[[PTR:.*]], <4 x float> %[[VALUE:.*]], <4 x i1> %[[MASK:.*]]) {
+// CHECK-NEXT: call void @llvm.masked.store.v4f32.p0(<4 x float> %[[VALUE]], ptr %[[PTR]], i32 1, <4 x i1> %[[MASK]])
+// CHECK-NEXT: call void @llvm.masked.store.v4f32.p0(<4 x float> %[[VALUE]], ptr %[[PTR]], i32 32, <4 x i1> %[[MASK]])
+// CHECK-NEXT: ret void
+// CHECK-NEXT: }
+llvm.func @masked_store_ops(%ptr: !ptr.ptr<#llvm.address_space<0>>, %value: vector<4xf32>, %mask: vector<4xi1>) {
+ // Masked store without alignment: expected to lower with an alignment argument of i32 1
+ ptr.masked_store %value, %ptr, %mask : vector<4xf32>, !ptr.ptr<#llvm.address_space<0>>
+ // Masked store with alignment = 32: expected to lower with an alignment argument of i32 32
+ ptr.masked_store %value, %ptr, %mask alignment = 32 : vector<4xf32>, !ptr.ptr<#llvm.address_space<0>>
+ llvm.return
+}
+
+// CHECK-LABEL: define void @masked_store_ops_i16
+// CHECK-SAME: (ptr %[[PTR:.*]], <8 x i16> %[[VALUE:.*]], <8 x i1> %[[MASK:.*]]) {
+// CHECK-NEXT: call void @llvm.masked.store.v8i16.p0(<8 x i16> %[[VALUE]], ptr %[[PTR]], i32 4, <8 x i1> %[[MASK]])
+// CHECK-NEXT: ret void
+// CHECK-NEXT: }
+llvm.func @masked_store_ops_i16(%ptr: !ptr.ptr<#llvm.address_space<0>>, %value: vector<8xi16>, %mask: vector<8xi1>) {
+ ptr.masked_store %value, %ptr, %mask alignment = 4 : vector<8xi16>, !ptr.ptr<#llvm.address_space<0>> // i16 elements; alignment = 4 becomes the i32 4 intrinsic argument
+ llvm.return
+}
+
+// CHECK-LABEL: define void @scatter_ops
+// CHECK-SAME: (<4 x float> %[[VALUE:.*]], <4 x ptr> %[[PTRS:.*]], <4 x i1> %[[MASK:.*]]) {
+// CHECK-NEXT: call void @llvm.masked.scatter.v4f32.v4p0(<4 x float> %[[VALUE]], <4 x ptr> %[[PTRS]], i32 1, <4 x i1> %[[MASK]])
+// CHECK-NEXT: call void @llvm.masked.scatter.v4f32.v4p0(<4 x float> %[[VALUE]], <4 x ptr> %[[PTRS]], i32 8, <4 x i1> %[[MASK]])
+// CHECK-NEXT: ret void
+// CHECK-NEXT: }
+llvm.func @scatter_ops(%value: vector<4xf32>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %mask: vector<4xi1>) {
+ // Scatter without alignment: expected to lower with an alignment argument of i32 1
+ ptr.scatter %value, %ptrs, %mask : vector<4xf32>, vector<4x!ptr.ptr<#llvm.address_space<0>>>
+ // Scatter with alignment = 8: expected to lower with an alignment argument of i32 8
+ ptr.scatter %value, %ptrs, %mask alignment = 8 : vector<4xf32>, vector<4x!ptr.ptr<#llvm.address_space<0>>>
+ llvm.return
+}
+
+// CHECK-LABEL: define void @scatter_ops_i64
+// CHECK-SAME: (<8 x i64> %[[VALUE:.*]], <8 x ptr> %[[PTRS:.*]], <8 x i1> %[[MASK:.*]]) {
+// CHECK-NEXT: call void @llvm.masked.scatter.v8i64.v8p0(<8 x i64> %[[VALUE]], <8 x ptr> %[[PTRS]], i32 16, <8 x i1> %[[MASK]])
+// CHECK-NEXT: ret void
+// CHECK-NEXT: }
+llvm.func @scatter_ops_i64(%value: vector<8xi64>, %ptrs: vector<8x!ptr.ptr<#llvm.address_space<0>>>, %mask: vector<8xi1>) {
+ ptr.scatter %value, %ptrs, %mask alignment = 16 : vector<8xi64>, vector<8x!ptr.ptr<#llvm.address_space<0>>> // i64 elements; alignment = 16 becomes the i32 16 intrinsic argument
+ llvm.return
+}
+
+// CHECK-LABEL: define void @mixed_masked_ops_address_spaces
+// CHECK-SAME: (ptr addrspace(3) %[[PTR_SHARED:.*]], <4 x ptr addrspace(3)> %[[PTRS_SHARED:.*]], <4 x i1> %[[MASK:.*]], <4 x double> %[[VALUE_F64:.*]], <4 x double> %[[PASSTHROUGH_F64:.*]]) {
+// CHECK-NEXT: %[[V0:.*]] = call <4 x double> @llvm.masked.gather.v4f64.v4p3(<4 x ptr addrspace(3)> %[[PTRS_SHARED]], i32 8, <4 x i1> %[[MASK]], <4 x double> %[[PASSTHROUGH_F64]])
+// CHECK-NEXT: call void @llvm.masked.scatter.v4f64.v4p3(<4 x double> %[[VALUE_F64]], <4 x ptr addrspace(3)> %[[PTRS_SHARED]], i32 8, <4 x i1> %[[MASK]])
+// CHECK-NEXT: %[[V1:.*]] = call <4 x double> @llvm.masked.load.v4f64.p3(ptr addrspace(3) %[[PTR_SHARED]], i32 8, <4 x i1> %[[MASK]], <4 x double> %[[PASSTHROUGH_F64]])
+// CHECK-NEXT: call void @llvm.masked.store.v4f64.p3(<4 x double> %[[VALUE_F64]], ptr addrspace(3) %[[PTR_SHARED]], i32 8, <4 x i1> %[[MASK]])
+// CHECK-NEXT: ret void
+// CHECK-NEXT: }
+llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>,
+ %mask: vector<4xi1>, %value: vector<4xf64>, %passthrough: vector<4xf64>) {
+ // All four ops on address space 3 with f64 elements; each must select the addrspace(3) (.p3 / .v4p3) intrinsic overload
+ %0 = ptr.gather %ptrs, %mask, %passthrough alignment = 8 : vector<4x!ptr.ptr<#llvm.address_space<3>>> -> vector<4xf64>
+ ptr.scatter %value, %ptrs, %mask alignment = 8 : vector<4xf64>, vector<4x!ptr.ptr<#llvm.address_space<3>>>
+ %1 = ptr.masked_load %ptr, %mask, %passthrough alignment = 8 : !ptr.ptr<#llvm.address_space<3>> -> vector<4xf64>
+ ptr.masked_store %value, %ptr, %mask alignment = 8 : vector<4xf64>, !ptr.ptr<#llvm.address_space<3>>
+ llvm.return
+}
>From 18686db9ff45b29d09d78c8dd06896ecfa23e807 Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Mon, 1 Sep 2025 19:33:23 +0000
Subject: [PATCH 2/3] address comments
---
mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 2 +-
.../LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp | 10 ----------
2 files changed, 1 insertion(+), 11 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 170513d57c7be..59eaaf7c55cce 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -316,7 +316,7 @@ def Ptr_MaskedLoadOp : Pointer_Op<"masked_load", [
attr-dict `:` qualified(type($ptr)) `->` type($result)
}];
let builders = [
- OpBuilder<(ins "Type":$resultType, "Value":$ptrs, "Value":$mask,
+ OpBuilder<(ins "Type":$resultType, "Value":$ptr, "Value":$mask,
"Value":$passthrough, CArg<"unsigned", "0">:$alignment)>
];
let hasVerifier = 1;
diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
index 3bcc16927b352..d777667022a98 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
@@ -273,11 +273,6 @@ convertMaskedStoreOp(MaskedStoreOp maskedStoreOp, llvm::IRBuilderBase &builder,
if (!value || !ptr || !mask)
return maskedStoreOp.emitError("Failed to lookup operands");
- // Get the value type.
- llvm::Type *valueType = value->getType();
- if (!valueType)
- return maskedStoreOp.emitError("Failed to get value type");
-
// Get the alignment.
llvm::MaybeAlign alignment(maskedStoreOp.getAlignment().value_or(0));
@@ -297,11 +292,6 @@ convertScatterOp(ScatterOp scatterOp, llvm::IRBuilderBase &builder,
if (!value || !ptrs || !mask)
return scatterOp.emitError("Failed to lookup operands");
- // Get the value type
- llvm::Type *valueType = value->getType();
- if (!valueType)
- return scatterOp.emitError("Failed to get value type");
-
// Get the alignment.
llvm::MaybeAlign alignment(scatterOp.getAlignment().value_or(0));
>From 16785abcc2b0409c526d867538ed73bea4626c99 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Wed, 3 Sep 2025 14:30:13 +0000
Subject: [PATCH 3/3] address comments
---
mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 12 ++----------
mlir/include/mlir/IR/OpBase.td | 4 ++++
mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 8 ++++----
3 files changed, 10 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 59eaaf7c55cce..5939c3646884c 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -49,14 +49,6 @@ def Ptr_Mask1DType :
def Ptr_Ptr1DType :
Ptr_ShapedValueType<[Ptr_PtrType], [HasAnyRankOfPred<[1]>]>;
-// Gets the type ID of a type.
-class TypeIDType<string name> :
- StrFunc<"$" # name # ".getType().getTypeID()">;
-
-// Checks that all type IDs match.
-class AllTypeIDsMatch<list<string> names> :
- AllMatchSameOperatorTrait<names, TypeIDType<"_self">.result, "type IDs">;
-
//===----------------------------------------------------------------------===//
// FromPtrOp
//===----------------------------------------------------------------------===//
@@ -139,7 +131,7 @@ def Ptr_GatherOp : Pointer_Op<"gather", [
let results = (outs Ptr_Any1DType:$result);
let assemblyFormat = [{
$ptrs `,` $mask `,` $passthrough (`alignment` `=` $alignment^)?
- attr-dict `:` qualified(type($ptrs)) `->` type($result)
+ attr-dict `:` type($ptrs) `->` type($result)
}];
let builders = [
OpBuilder<(ins "Type":$resultType, "Value":$ptrs, "Value":$mask,
@@ -409,7 +401,7 @@ def Ptr_ScatterOp : Pointer_Op<"scatter", [
AlignmentProp:$alignment);
let assemblyFormat = [{
$value `,` $ptrs `,` $mask (`alignment` `=` $alignment^)?
- attr-dict `:` type($value) `,` qualified(type($ptrs))
+ attr-dict `:` type($value) `,` type($ptrs)
}];
let builders = [
OpBuilder<(ins "Value":$value, "Value":$ptrs, "Value":$mask,
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index af8c072a7a364..8d7dafae0ee76 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -556,6 +556,10 @@ class AllShapesMatch<list<string> names> :
class AllTypesMatch<list<string> names> :
AllMatchSameOperatorTrait<names, "$_self.getType()", "type">;
+// Checks that the types of all named values have equal TypeIDs (compares $_self.getType().getTypeID(); presumably the type kind only, not its parameters).
+class AllTypeIDsMatch<list<string> names> :
+ AllMatchSameOperatorTrait<names, "$_self.getType().getTypeID()", "type IDs">;
+
// A type constraint that verifies that a shaped type matches the size and
// element type of a container with element types. More specifically, it denotes
// shapedArg.getType().getNumElements() == elementsArg.size() &&
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index 74e35f48b033c..92ce9be97dd2c 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -120,7 +120,7 @@ LogicalResult GatherOp::verify() {
cast<PtrType>(getPtrs().getType().getElementType()).getMemorySpace();
DataLayout dataLayout = DataLayout::closest(*this);
if (!ms.isValidLoad(getResult().getType(), AtomicOrdering::not_atomic,
- getAlignment(), dataLayout, emitDiag))
+ getAlignment(), &dataLayout, emitDiag))
return failure();
// Verify the alignment.
@@ -212,7 +212,7 @@ LogicalResult MaskedLoadOp::verify() {
MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
DataLayout dataLayout = DataLayout::closest(*this);
if (!ms.isValidLoad(getResult().getType(), AtomicOrdering::not_atomic,
- getAlignment(), dataLayout, emitDiag))
+ getAlignment(), &dataLayout, emitDiag))
return failure();
// Verify the alignment.
@@ -243,7 +243,7 @@ LogicalResult MaskedStoreOp::verify() {
MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
DataLayout dataLayout = DataLayout::closest(*this);
if (!ms.isValidStore(getValue().getType(), AtomicOrdering::not_atomic,
- getAlignment(), dataLayout, emitDiag))
+ getAlignment(), &dataLayout, emitDiag))
return failure();
// Verify the alignment.
@@ -276,7 +276,7 @@ LogicalResult ScatterOp::verify() {
cast<PtrType>(getPtrs().getType().getElementType()).getMemorySpace();
DataLayout dataLayout = DataLayout::closest(*this);
if (!ms.isValidStore(getValue().getType(), AtomicOrdering::not_atomic,
- getAlignment(), dataLayout, emitDiag))
+ getAlignment(), &dataLayout, emitDiag))
return failure();
// Verify the alignment.
More information about the Mlir-commits
mailing list