[Mlir-commits] [mlir] [mlir][spirv] Add OpTypeSampler and OpSampledImage support (PR #189891)
Bryson Miller
llvmlistbot at llvm.org
Tue Apr 7 07:05:34 PDT 2026
https://github.com/abm-77 updated https://github.com/llvm/llvm-project/pull/189891
>From b244b4a4d50f420fd214597a20c5feedeba03e4e Mon Sep 17 00:00:00 2001
From: abm-77 <andrewmiller77 at protonmail.com>
Date: Tue, 31 Mar 2026 23:34:11 -0700
Subject: [PATCH] [mlir][spirv] Add OpTypeSampler and OpSampledImage support
Add the missing !spirv.sampler type (OpTypeSampler, opcode 26) and the
spirv.SampledImage op (OpSampledImage, opcode 86) to MLIR's SPIR-V
dialect, including SPIR-V serialization and deserialization support.
This closes a gap: sampling ops such as spirv.ImageSampleImplicitLod
consume sampled-image values, but there was previously no way to create
one from a separate image and sampler.
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 11 ++++-
.../mlir/Dialect/SPIRV/IR/SPIRVImageOps.td | 48 +++++++++++++++++++
.../mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 10 ++++
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 8 +++-
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 13 ++++-
.../SPIRV/Deserialization/DeserializeOps.cpp | 1 +
.../SPIRV/Deserialization/Deserializer.cpp | 11 +++++
.../SPIRV/Deserialization/Deserializer.h | 2 +
.../Target/SPIRV/Serialization/Serializer.cpp | 5 ++
mlir/test/Dialect/SPIRV/IR/composite-ops.mlir | 2 +-
mlir/test/Dialect/SPIRV/IR/image-ops.mlir | 44 +++++++++++++++++
mlir/test/Dialect/SPIRV/IR/types.mlir | 9 ++++
mlir/test/Target/SPIRV/image-ops.mlir | 10 ++++
mlir/test/Target/SPIRV/sampled-image.mlir | 3 ++
14 files changed, 171 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 11a91958d7484..c4d123e0f539c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4255,6 +4255,7 @@ def SPIRV_IsMatrixType : CPred<"::llvm::isa<::mlir::spirv::MatrixType>($_self)">
def SPIRV_IsPtrType : CPred<"::llvm::isa<::mlir::spirv::PointerType>($_self)">;
def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">;
def SPIRV_IsSampledImageType : CPred<"::llvm::isa<::mlir::spirv::SampledImageType>($_self)">;
+def SPIRV_IsSamplerType : CPred<"::llvm::isa<::mlir::spirv::SamplerType>($_self)">;
def SPIRV_IsStructType : CPred<"::llvm::isa<::mlir::spirv::StructType>($_self)">;
def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_self)">;
@@ -4298,6 +4299,8 @@ def SPIRV_AnyStruct : DialectType<SPIRV_Dialect, SPIRV_IsStructType,
"any SPIR-V struct type">;
def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
"any SPIR-V sampled image type">;
+def SPIRV_AnySampler : DialectType<SPIRV_Dialect, SPIRV_IsSamplerType,
+ "any SPIR-V sampler type">;
def SPIRV_AnyTensorArm : DialectType<SPIRV_Dialect, SPIRV_IsTensorArmType,
"any SPIR-V tensorArm type">;
@@ -4311,7 +4314,7 @@ def SPIRV_Type : AnyTypeOf<[
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector,
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
- SPIRV_AnyImage, SPIRV_AnyTensorArm
+ SPIRV_AnySampler, SPIRV_AnyImage, SPIRV_AnyTensorArm
]>;
def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
@@ -4412,6 +4415,7 @@ def SPIRV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 2
def SPIRV_OC_OpTypeVector : I32EnumAttrCase<"OpTypeVector", 23>;
def SPIRV_OC_OpTypeMatrix : I32EnumAttrCase<"OpTypeMatrix", 24>;
def SPIRV_OC_OpTypeImage : I32EnumAttrCase<"OpTypeImage", 25>;
+def SPIRV_OC_OpTypeSampler : I32EnumAttrCase<"OpTypeSampler", 26>;
def SPIRV_OC_OpTypeSampledImage : I32EnumAttrCase<"OpTypeSampledImage", 27>;
def SPIRV_OC_OpTypeArray : I32EnumAttrCase<"OpTypeArray", 28>;
def SPIRV_OC_OpTypeRuntimeArray : I32EnumAttrCase<"OpTypeRuntimeArray", 29>;
@@ -4449,6 +4453,7 @@ def SPIRV_OC_OpCompositeConstruct : I32EnumAttrCase<"OpCompositeCons
def SPIRV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
def SPIRV_OC_OpCompositeInsert : I32EnumAttrCase<"OpCompositeInsert", 82>;
def SPIRV_OC_OpTranspose : I32EnumAttrCase<"OpTranspose", 84>;
+def SPIRV_OC_OpSampledImage : I32EnumAttrCase<"OpSampledImage", 86>;
def SPIRV_OC_OpImageSampleImplicitLod : I32EnumAttrCase<"OpImageSampleImplicitLod", 87>;
def SPIRV_OC_OpImageSampleExplicitLod : I32EnumAttrCase<"OpImageSampleExplicitLod", 88>;
def SPIRV_OC_OpImageSampleProjDrefImplicitLod : I32EnumAttrCase<"OpImageSampleProjDrefImplicitLod", 93>;
@@ -4661,7 +4666,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpMemoryModel, SPIRV_OC_OpEntryPoint, SPIRV_OC_OpExecutionMode,
SPIRV_OC_OpCapability, SPIRV_OC_OpTypeVoid, SPIRV_OC_OpTypeBool,
SPIRV_OC_OpTypeInt, SPIRV_OC_OpTypeFloat, SPIRV_OC_OpTypeVector,
- SPIRV_OC_OpTypeMatrix, SPIRV_OC_OpTypeImage, SPIRV_OC_OpTypeSampledImage,
+ SPIRV_OC_OpTypeMatrix, SPIRV_OC_OpTypeImage, SPIRV_OC_OpTypeSampler,
+ SPIRV_OC_OpTypeSampledImage,
SPIRV_OC_OpTypeArray, SPIRV_OC_OpTypeRuntimeArray, SPIRV_OC_OpTypeStruct,
SPIRV_OC_OpTypePointer, SPIRV_OC_OpTypeFunction, SPIRV_OC_OpTypeForwardPointer,
SPIRV_OC_OpConstantTrue, SPIRV_OC_OpConstantFalse, SPIRV_OC_OpConstant,
@@ -4677,6 +4683,7 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpVectorInsertDynamic, SPIRV_OC_OpVectorShuffle,
SPIRV_OC_OpCompositeConstruct, SPIRV_OC_OpCompositeExtract,
SPIRV_OC_OpCompositeInsert, SPIRV_OC_OpTranspose,
+ SPIRV_OC_OpSampledImage,
SPIRV_OC_OpImageSampleImplicitLod, SPIRV_OC_OpImageSampleExplicitLod,
SPIRV_OC_OpImageSampleProjDrefImplicitLod, SPIRV_OC_OpImageFetch,
SPIRV_OC_OpImageDrefGather,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td
index e23efa57e5e53..e5ff4c5d96b4a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td
@@ -65,6 +65,54 @@ def SPIRV_SampledImageTransform : StrFunc<"llvm::cast<::mlir::spirv::SampledImag
// -----
+def SPIRV_SampledImageOp : SPIRV_Op<"SampledImage",
+ [Pure,
+ TypesMatchWith<"type of 'result' wraps the image type of 'image'",
+ "result", "image",
+ "::llvm::cast<spirv::SampledImageType>($_self).getImageType()">,
+ SPIRV_DimIsNot<"image", ["SubpassData"]>,
+ SPIRV_SampledOperandIs<"image", ["SamplerUnknown", "NeedSampler"]>]> {
+ let summary = "Create a sampled image, containing both a sampler and an image.";
+
+ let description = [{
+ Result Type must be OpTypeSampledImage whose Image Type is the same as
+ the type of the Image operand.
+
+ Image must be an object whose type is an OpTypeImage, whose Sampled
+ operand is 0 or 1. The Dim operand of the underlying OpTypeImage must
+ not be SubpassData. Additionally, starting with version 1.6, the Dim
+ operand must not be Buffer.
+
+ Sampler must be an object of a type made by OpTypeSampler.
+
+ <!-- End of AutoGen section -->
+
+ #### Example:
+
+ ```mlir
+ %0 = spirv.SampledImage %image, %sampler : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, !spirv.sampler -> !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_AnyImage:$image,
+ SPIRV_AnySampler:$sampler
+ );
+
+ let results = (outs
+ SPIRV_AnySampledImage:$result
+ );
+
+ let assemblyFormat = [{
+ $image `,` $sampler attr-dict `:` type($image) `,` type($sampler)
+ `->` type($result)
+ }];
+
+ let hasVerifier = 0;
+}
+
+// -----
+
def SPIRV_ImageDrefGatherOp : SPIRV_Op<"ImageDrefGather",
[Pure,
SPIRV_DimIs<"sampled_image", ["Dim2D", "Cube", "Rect"], SPIRV_SampledImageTransform.result>,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 4a0c29d4b5d90..9864f644aa93e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -228,6 +228,16 @@ class SampledImageType
Type getImageType() const;
};
+/// SPIR-V sampler type (OpTypeSampler), spelled `!spirv.sampler`; it carries no parameters.
+class SamplerType : public Type::TypeBase<SamplerType, SPIRVType, TypeStorage> {
+public:
+ using Base::Base;
+
+ static constexpr StringLiteral name = "spirv.sampler";
+
+ static SamplerType get(MLIRContext *context);
+};
+
/// SPIR-V struct type. Two kinds of struct types are supported:
/// - Literal: a literal struct type is uniqued by its fields (types + offset
/// info + decoration info).
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 5782b42dba026..036d48f0fd637 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -861,6 +861,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
return parseRuntimeArrayType(*this, parser);
if (keyword == "sampled_image")
return parseSampledImageType(*this, parser);
+ if (keyword == "sampler")
+ return SamplerType::get(getContext());
if (keyword == "struct")
return parseStructType(*this, parser);
if (keyword == "matrix")
@@ -907,6 +909,8 @@ static void print(SampledImageType type, DialectAsmPrinter &os) {
os << "sampled_image<" << type.getImageType() << ">";
}
+static void print(SamplerType type, DialectAsmPrinter &os) { os << "sampler"; }
+
static void print(StructType type, DialectAsmPrinter &os) {
FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
@@ -1001,8 +1005,8 @@ static void print(TensorArmType type, DialectAsmPrinter &os) {
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
.Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
- ImageType, SampledImageType, StructType, MatrixType, TensorArmType>(
- [&](auto type) { print(type, os); })
+ ImageType, SampledImageType, SamplerType, StructType, MatrixType,
+ TensorArmType>([&](auto type) { print(type, os); })
.DefaultUnreachable("Unhandled SPIR-V type");
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 331d98c1d9313..c4dd4cea778d7 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -57,6 +57,7 @@ class TypeExtensionVisitor {
for (Type elementType : concreteType.getElementTypes())
add(elementType);
})
+ .Case<SamplerType>([](auto) { /* no extensions */ })
.DefaultUnreachable("Unhandled type");
}
@@ -107,6 +108,7 @@ class TypeCapabilityVisitor {
for (Type elementType : concreteType.getElementTypes())
add(elementType);
})
+ .Case<SamplerType>([](auto) { /* no capabilities */ })
.DefaultUnreachable("Unhandled type");
}
@@ -794,6 +796,14 @@ SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+//===----------------------------------------------------------------------===//
+// SamplerType
+//===----------------------------------------------------------------------===//
+
+SamplerType SamplerType::get(MLIRContext *context) {
+ return Base::get(context);
+}
+
//===----------------------------------------------------------------------===//
// StructType
//===----------------------------------------------------------------------===//
@@ -1331,5 +1341,6 @@ TensorArmType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
void SPIRVDialect::registerTypes() {
addTypes<ArrayType, CooperativeMatrixType, ImageType, MatrixType, PointerType,
- RuntimeArrayType, SampledImageType, StructType, TensorArmType>();
+ RuntimeArrayType, SampledImageType, SamplerType, StructType,
+ TensorArmType>();
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index cc6302126d64a..0faa5f0f29d7d 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -183,6 +183,7 @@ LogicalResult spirv::Deserializer::processInstruction(
case spirv::Opcode::OpTypeArray:
case spirv::Opcode::OpTypeFunction:
case spirv::Opcode::OpTypeImage:
+ case spirv::Opcode::OpTypeSampler:
case spirv::Opcode::OpTypeSampledImage:
case spirv::Opcode::OpTypeRuntimeArray:
case spirv::Opcode::OpTypeStruct:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index f98236c5daece..9557b4647958d 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1150,6 +1150,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
return processFunctionType(operands);
case spirv::Opcode::OpTypeImage:
return processImageType(operands);
+ case spirv::Opcode::OpTypeSampler:
+ return processSamplerType(operands);
case spirv::Opcode::OpTypeSampledImage:
return processSampledImageType(operands);
case spirv::Opcode::OpTypeRuntimeArray:
@@ -1634,6 +1636,15 @@ spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
return success();
}
+LogicalResult
+spirv::Deserializer::processSamplerType(ArrayRef<uint32_t> operands) {
+ if (operands.size() != 1)
+ return emitError(unknownLoc, "OpTypeSampler must have no parameters");
+
+ typeMap[operands[0]] = spirv::SamplerType::get(context);
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Constant
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 50c935036158c..e0743503acc5b 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -317,6 +317,8 @@ class Deserializer {
LogicalResult processSampledImageType(ArrayRef<uint32_t> operands);
+ LogicalResult processSamplerType(ArrayRef<uint32_t> operands);
+
LogicalResult processRuntimeArrayType(ArrayRef<uint32_t> operands);
LogicalResult processStructType(ArrayRef<uint32_t> operands);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index c21cb27b072f1..aaa80470f40e1 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -728,6 +728,11 @@ LogicalResult Serializer::prepareBasicType(
return processTypeDecoration(loc, runtimeArrayType, resultID);
}
+ if (isa<spirv::SamplerType>(type)) {
+ typeEnum = spirv::Opcode::OpTypeSampler;
+ return success();
+ }
+
if (auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
typeEnum = spirv::Opcode::OpTypeSampledImage;
uint32_t imageTypeID = 0;
diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index 6e4126172f670..ef04b949c5219 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -100,7 +100,7 @@ func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2
// -----
func.func @composite_construct_vector_rank_two(%arg0: vector<2x2xi1>, %arg1: vector<2x2xi1>) -> vector<4x2xi1> {
- // expected-error @+1 {{ op operand #0 must be variadic of void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V sampled image type or any SPIR-V image type or any SPIR-V tensorArm type, but got 'vector<2x2xi1>'}}
+ // expected-error @+1 {{ op operand #0 must be variadic of void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V sampled image type or any SPIR-V sampler type or any SPIR-V image type or any SPIR-V tensorArm type, but got 'vector<2x2xi1>'}}
%0 = spirv.CompositeConstruct %arg0, %arg1 : (vector<2x2xi1>, vector<2x2xi1>) -> vector<4x2xi1>
return %0: vector<4x2xi1>
}
diff --git a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
index 12b5f2ce62a68..c4f90cfcb9a49 100644
--- a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
@@ -468,3 +468,47 @@ func.func @gard_too_many_args(%arg0 : !spirv.sampled_image<!spirv.image<f32, Dim
%0 = spirv.ImageSampleExplicitLod %arg0, %arg1 ["Grad"], %arg2, %arg2, %arg2 : !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Rgba8>>, vector<2xf32>, vector<2xf32>, vector<2xf32>, vector<2xf32> -> vector<4xf32>
spirv.Return
}
+
+//===----------------------------------------------------------------------===//
+// spirv.SampledImage
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @sampled_image(%arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, %arg1 : !spirv.sampler) -> () {
+ // CHECK: spirv.SampledImage {{%.*}}, {{%.*}} : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, !spirv.sampler -> !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>
+ %0 = spirv.SampledImage %arg0, %arg1 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, !spirv.sampler -> !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>
+ spirv.Return
+}
+
+// -----
+
+func.func @sampled_image_sampler_unknown(%arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>, %arg1 : !spirv.sampler) -> () {
+ // CHECK: spirv.SampledImage {{%.*}}, {{%.*}} : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>, !spirv.sampler -> !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>>
+ %0 = spirv.SampledImage %arg0, %arg1 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>, !spirv.sampler -> !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>>
+ spirv.Return
+}
+
+// -----
+
+func.func @sampled_image_error(%arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, %arg1 : !spirv.sampler) -> () {
+ // expected-error @+1 {{type of 'result' wraps the image type of 'image'}}
+ %0 = spirv.SampledImage %arg0, %arg1 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, !spirv.sampler -> !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>
+ spirv.Return
+}
+
+// -----
+
+func.func @sampled_image_dim_subpassdata(%arg0 : !spirv.image<f32, SubpassData, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, %arg1 : !spirv.sampler) -> () {
+ // expected-error @+1 {{sampled image Dim must not be SubpassData or Buffer, got SubpassData}}
+ %0 = spirv.SampledImage %arg0, %arg1 : !spirv.image<f32, SubpassData, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, !spirv.sampler -> !spirv.sampled_image<!spirv.image<f32, SubpassData, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>
+ spirv.Return
+}
+
+// -----
+
+func.func @sampled_image_sampled_operand(%arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>, %arg1 : !spirv.sampler) -> () {
+ // expected-error @+1 {{the sampled operand of the underlying image must be SamplerUnknown or NeedSampler}}
+ %0 = spirv.SampledImage %arg0, %arg1 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>, !spirv.sampler -> !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>
+ spirv.Return
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 710673b73cee5..99443a13e0ec3 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -234,6 +234,15 @@ func.func private @image_parameters_nocomma_5(!spirv.image<f32, Dim1D, NoDepth,
// -----
+//===----------------------------------------------------------------------===//
+// SamplerType
+//===----------------------------------------------------------------------===//
+
+// CHECK: func private @sampler_type(!spirv.sampler)
+func.func private @sampler_type(!spirv.sampler) -> ()
+
+// -----
+
//===----------------------------------------------------------------------===//
// SampledImageType
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/image-ops.mlir b/mlir/test/Target/SPIRV/image-ops.mlir
index 3593d9b0e9b38..b664561ce9b01 100644
--- a/mlir/test/Target/SPIRV/image-ops.mlir
+++ b/mlir/test/Target/SPIRV/image-ops.mlir
@@ -61,3 +61,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, StorageImageWrit
spirv.Return
}
}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
+ spirv.func @sampled_image(%arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, %arg1 : !spirv.sampler) "None" {
+ // CHECK: {{%.*}} = spirv.SampledImage {{%.*}}, {{%.*}} : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, !spirv.sampler -> !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>
+ %0 = spirv.SampledImage %arg0, %arg1 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, !spirv.sampler -> !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>
+ spirv.Return
+ }
+}
diff --git a/mlir/test/Target/SPIRV/sampled-image.mlir b/mlir/test/Target/SPIRV/sampled-image.mlir
index ff068208540f4..4f6f3256acbac 100644
--- a/mlir/test/Target/SPIRV/sampled-image.mlir
+++ b/mlir/test/Target/SPIRV/sampled-image.mlir
@@ -14,4 +14,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Sampled1D, Sampl
// CHECK: !spirv.ptr<!spirv.sampled_image<!spirv.image<i32, Rect, DepthUnknown, Arrayed, MultiSampled, NeedSampler, R8ui>>, UniformConstant>
spirv.GlobalVariable @var2 bind(0, 0) : !spirv.ptr<!spirv.sampled_image<!spirv.image<i32, Rect, DepthUnknown, Arrayed, MultiSampled, NeedSampler, R8ui>>, UniformConstant>
+
+ // CHECK: !spirv.ptr<!spirv.sampler, UniformConstant>
+ spirv.GlobalVariable @var3 bind(0, 2) : !spirv.ptr<!spirv.sampler, UniformConstant>
}
More information about the Mlir-commits mailing list