[Mlir-commits] [mlir] [mlir][spirv] Add OpTypeSampler and OpSampledImage support (PR #189891)

Bryson Miller llvmlistbot at llvm.org
Tue Apr 7 07:05:34 PDT 2026


https://github.com/abm-77 updated https://github.com/llvm/llvm-project/pull/189891

>From b244b4a4d50f420fd214597a20c5feedeba03e4e Mon Sep 17 00:00:00 2001
From: abm-77 <andrewmiller77 at protonmail.com>
Date: Tue, 31 Mar 2026 23:34:11 -0700
Subject: [PATCH] [mlir][spirv] Add OpTypeSampler and OpSampledImage support

Add the missing !spirv.sampler type (OpTypeSampler, opcode 26) and
spirv.SampledImage op (OpSampledImage, opcode 86) to MLIR's SPIR-V
dialect. This closes a gap where sampling ops like
spirv.ImageSampleImplicitLod consume sampled image values but there
was no way to create one from a separate image and sampler.
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        | 11 ++++-
 .../mlir/Dialect/SPIRV/IR/SPIRVImageOps.td    | 48 +++++++++++++++++++
 .../mlir/Dialect/SPIRV/IR/SPIRVTypes.h        | 10 ++++
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp    |  8 +++-
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp      | 13 ++++-
 .../SPIRV/Deserialization/DeserializeOps.cpp  |  1 +
 .../SPIRV/Deserialization/Deserializer.cpp    | 11 +++++
 .../SPIRV/Deserialization/Deserializer.h      |  2 +
 .../Target/SPIRV/Serialization/Serializer.cpp |  5 ++
 mlir/test/Dialect/SPIRV/IR/composite-ops.mlir |  2 +-
 mlir/test/Dialect/SPIRV/IR/image-ops.mlir     | 44 +++++++++++++++++
 mlir/test/Dialect/SPIRV/IR/types.mlir         |  9 ++++
 mlir/test/Target/SPIRV/image-ops.mlir         | 10 ++++
 mlir/test/Target/SPIRV/sampled-image.mlir     |  3 ++
 14 files changed, 171 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 11a91958d7484..c4d123e0f539c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4255,6 +4255,7 @@ def SPIRV_IsMatrixType : CPred<"::llvm::isa<::mlir::spirv::MatrixType>($_self)">
 def SPIRV_IsPtrType : CPred<"::llvm::isa<::mlir::spirv::PointerType>($_self)">;
 def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">;
 def SPIRV_IsSampledImageType : CPred<"::llvm::isa<::mlir::spirv::SampledImageType>($_self)">;
+def SPIRV_IsSamplerType : CPred<"::llvm::isa<::mlir::spirv::SamplerType>($_self)">;
 def SPIRV_IsStructType : CPred<"::llvm::isa<::mlir::spirv::StructType>($_self)">;
 def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_self)">;
 
@@ -4298,6 +4299,8 @@ def SPIRV_AnyStruct : DialectType<SPIRV_Dialect, SPIRV_IsStructType,
                                 "any SPIR-V struct type">;
 def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
                                 "any SPIR-V sampled image type">;
+def SPIRV_AnySampler : DialectType<SPIRV_Dialect, SPIRV_IsSamplerType,
+                                "any SPIR-V sampler type">;
 def SPIRV_AnyTensorArm : DialectType<SPIRV_Dialect, SPIRV_IsTensorArmType,
                                  "any SPIR-V tensorArm type">;
 
@@ -4311,7 +4314,7 @@ def SPIRV_Type : AnyTypeOf<[
     SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector,
     SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
     SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
-    SPIRV_AnyImage, SPIRV_AnyTensorArm
+    SPIRV_AnySampler, SPIRV_AnyImage, SPIRV_AnyTensorArm
   ]>;
 
 def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
@@ -4412,6 +4415,7 @@ def SPIRV_OC_OpTypeFloat                      : I32EnumAttrCase<"OpTypeFloat", 2
 def SPIRV_OC_OpTypeVector                     : I32EnumAttrCase<"OpTypeVector", 23>;
 def SPIRV_OC_OpTypeMatrix                     : I32EnumAttrCase<"OpTypeMatrix", 24>;
 def SPIRV_OC_OpTypeImage                      : I32EnumAttrCase<"OpTypeImage", 25>;
+def SPIRV_OC_OpTypeSampler                    : I32EnumAttrCase<"OpTypeSampler", 26>;
 def SPIRV_OC_OpTypeSampledImage               : I32EnumAttrCase<"OpTypeSampledImage", 27>;
 def SPIRV_OC_OpTypeArray                      : I32EnumAttrCase<"OpTypeArray", 28>;
 def SPIRV_OC_OpTypeRuntimeArray               : I32EnumAttrCase<"OpTypeRuntimeArray", 29>;
@@ -4449,6 +4453,7 @@ def SPIRV_OC_OpCompositeConstruct             : I32EnumAttrCase<"OpCompositeCons
 def SPIRV_OC_OpCompositeExtract               : I32EnumAttrCase<"OpCompositeExtract", 81>;
 def SPIRV_OC_OpCompositeInsert                : I32EnumAttrCase<"OpCompositeInsert", 82>;
 def SPIRV_OC_OpTranspose                      : I32EnumAttrCase<"OpTranspose", 84>;
+def SPIRV_OC_OpSampledImage                   : I32EnumAttrCase<"OpSampledImage", 86>;
 def SPIRV_OC_OpImageSampleImplicitLod         : I32EnumAttrCase<"OpImageSampleImplicitLod", 87>;
 def SPIRV_OC_OpImageSampleExplicitLod         : I32EnumAttrCase<"OpImageSampleExplicitLod", 88>;
 def SPIRV_OC_OpImageSampleProjDrefImplicitLod : I32EnumAttrCase<"OpImageSampleProjDrefImplicitLod", 93>;
@@ -4661,7 +4666,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpMemoryModel, SPIRV_OC_OpEntryPoint, SPIRV_OC_OpExecutionMode,
       SPIRV_OC_OpCapability, SPIRV_OC_OpTypeVoid, SPIRV_OC_OpTypeBool,
       SPIRV_OC_OpTypeInt, SPIRV_OC_OpTypeFloat, SPIRV_OC_OpTypeVector,
-      SPIRV_OC_OpTypeMatrix, SPIRV_OC_OpTypeImage, SPIRV_OC_OpTypeSampledImage,
+      SPIRV_OC_OpTypeMatrix, SPIRV_OC_OpTypeImage, SPIRV_OC_OpTypeSampler,
+      SPIRV_OC_OpTypeSampledImage,
       SPIRV_OC_OpTypeArray, SPIRV_OC_OpTypeRuntimeArray, SPIRV_OC_OpTypeStruct,
       SPIRV_OC_OpTypePointer, SPIRV_OC_OpTypeFunction, SPIRV_OC_OpTypeForwardPointer,
       SPIRV_OC_OpConstantTrue, SPIRV_OC_OpConstantFalse, SPIRV_OC_OpConstant,
@@ -4677,6 +4683,7 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpVectorInsertDynamic, SPIRV_OC_OpVectorShuffle,
       SPIRV_OC_OpCompositeConstruct, SPIRV_OC_OpCompositeExtract,
       SPIRV_OC_OpCompositeInsert, SPIRV_OC_OpTranspose,
+      SPIRV_OC_OpSampledImage,
       SPIRV_OC_OpImageSampleImplicitLod, SPIRV_OC_OpImageSampleExplicitLod,
       SPIRV_OC_OpImageSampleProjDrefImplicitLod, SPIRV_OC_OpImageFetch,
       SPIRV_OC_OpImageDrefGather,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td
index e23efa57e5e53..e5ff4c5d96b4a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td
@@ -65,6 +65,54 @@ def SPIRV_SampledImageTransform : StrFunc<"llvm::cast<::mlir::spirv::SampledImag
 
 // -----
 
+def SPIRV_SampledImageOp : SPIRV_Op<"SampledImage",
+    [Pure,
+     TypesMatchWith<"type of 'result' wraps the image type of 'image'",
+                    "result", "image",
+                    "::llvm::cast<spirv::SampledImageType>($_self).getImageType()">,
+     SPIRV_DimIsNot<"image", ["SubpassData"]>,
+     SPIRV_SampledOperandIs<"image", ["SamplerUnknown", "NeedSampler"]>]> {
+  let summary = "Create a sampled image, containing both a sampler and an image.";
+
+  let description = [{
+    Result Type must be OpTypeSampledImage whose Image Type is the same as
+    the type of the Image operand.
+
+    Image must be an object whose type is an OpTypeImage, whose Sampled
+    operand is 0 or 1. The Dim operand of the underlying OpTypeImage must
+    not be SubpassData. Additionally, starting with version 1.6, the Dim
+    operand must not be Buffer.
+
+    Sampler must be an object of a type made by OpTypeSampler.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %0 = spirv.SampledImage %image, %sampler : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, !spirv.sampler -> !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_AnyImage:$image,
+    SPIRV_AnySampler:$sampler
+  );
+
+  let results = (outs
+    SPIRV_AnySampledImage:$result
+  );
+
+  let assemblyFormat = [{
+    $image `,` $sampler attr-dict `:` type($image) `,` type($sampler)
+    `->` type($result)
+  }];
+
+  let hasVerifier = 0;
+}
+
+// -----
+
 def SPIRV_ImageDrefGatherOp : SPIRV_Op<"ImageDrefGather", 
     [Pure,
      SPIRV_DimIs<"sampled_image", ["Dim2D", "Cube", "Rect"], SPIRV_SampledImageTransform.result>,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 4a0c29d4b5d90..9864f644aa93e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -228,6 +228,16 @@ class SampledImageType
   Type getImageType() const;
 };
 
+/// SPIR-V sampler type (OpTypeSampler).
+class SamplerType : public Type::TypeBase<SamplerType, SPIRVType, TypeStorage> {
+public:
+  using Base::Base;
+
+  static constexpr StringLiteral name = "spirv.sampler";
+
+  static SamplerType get(MLIRContext *context);
+};
+
 /// SPIR-V struct type. Two kinds of struct types are supported:
 /// - Literal: a literal struct type is uniqued by its fields (types + offset
 /// info + decoration info).
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 5782b42dba026..036d48f0fd637 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -861,6 +861,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
     return parseRuntimeArrayType(*this, parser);
   if (keyword == "sampled_image")
     return parseSampledImageType(*this, parser);
+  if (keyword == "sampler")
+    return SamplerType::get(getContext());
   if (keyword == "struct")
     return parseStructType(*this, parser);
   if (keyword == "matrix")
@@ -907,6 +909,8 @@ static void print(SampledImageType type, DialectAsmPrinter &os) {
   os << "sampled_image<" << type.getImageType() << ">";
 }
 
+static void print(SamplerType type, DialectAsmPrinter &os) { os << "sampler"; }
+
 static void print(StructType type, DialectAsmPrinter &os) {
   FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
 
@@ -1001,8 +1005,8 @@ static void print(TensorArmType type, DialectAsmPrinter &os) {
 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
   TypeSwitch<Type>(type)
       .Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
-            ImageType, SampledImageType, StructType, MatrixType, TensorArmType>(
-          [&](auto type) { print(type, os); })
+            ImageType, SampledImageType, SamplerType, StructType, MatrixType,
+            TensorArmType>([&](auto type) { print(type, os); })
       .DefaultUnreachable("Unhandled SPIR-V type");
 }
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 331d98c1d9313..c4dd4cea778d7 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -57,6 +57,7 @@ class TypeExtensionVisitor {
           for (Type elementType : concreteType.getElementTypes())
             add(elementType);
         })
+        .Case<SamplerType>([](auto) { /* no extensions */ })
         .DefaultUnreachable("Unhandled type");
   }
 
@@ -107,6 +108,7 @@ class TypeCapabilityVisitor {
           for (Type elementType : concreteType.getElementTypes())
             add(elementType);
         })
+        .Case<SamplerType>([](auto) { /* no capabilities */ })
         .DefaultUnreachable("Unhandled type");
   }
 
@@ -794,6 +796,14 @@ SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// SamplerType
+//===----------------------------------------------------------------------===//
+
+SamplerType SamplerType::get(MLIRContext *context) {
+  return Base::get(context);
+}
+
 //===----------------------------------------------------------------------===//
 // StructType
 //===----------------------------------------------------------------------===//
@@ -1331,5 +1341,6 @@ TensorArmType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
 
 void SPIRVDialect::registerTypes() {
   addTypes<ArrayType, CooperativeMatrixType, ImageType, MatrixType, PointerType,
-           RuntimeArrayType, SampledImageType, StructType, TensorArmType>();
+           RuntimeArrayType, SampledImageType, SamplerType, StructType,
+           TensorArmType>();
 }
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index cc6302126d64a..0faa5f0f29d7d 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -183,6 +183,7 @@ LogicalResult spirv::Deserializer::processInstruction(
   case spirv::Opcode::OpTypeArray:
   case spirv::Opcode::OpTypeFunction:
   case spirv::Opcode::OpTypeImage:
+  case spirv::Opcode::OpTypeSampler:
   case spirv::Opcode::OpTypeSampledImage:
   case spirv::Opcode::OpTypeRuntimeArray:
   case spirv::Opcode::OpTypeStruct:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index f98236c5daece..9557b4647958d 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1150,6 +1150,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
     return processFunctionType(operands);
   case spirv::Opcode::OpTypeImage:
     return processImageType(operands);
+  case spirv::Opcode::OpTypeSampler:
+    return processSamplerType(operands);
   case spirv::Opcode::OpTypeSampledImage:
     return processSampledImageType(operands);
   case spirv::Opcode::OpTypeRuntimeArray:
@@ -1634,6 +1636,15 @@ spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+LogicalResult
+spirv::Deserializer::processSamplerType(ArrayRef<uint32_t> operands) {
+  if (operands.size() != 1)
+    return emitError(unknownLoc, "OpTypeSampler must have no parameters");
+
+  typeMap[operands[0]] = spirv::SamplerType::get(context);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Constant
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 50c935036158c..e0743503acc5b 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -317,6 +317,8 @@ class Deserializer {
 
   LogicalResult processSampledImageType(ArrayRef<uint32_t> operands);
 
+  LogicalResult processSamplerType(ArrayRef<uint32_t> operands);
+
   LogicalResult processRuntimeArrayType(ArrayRef<uint32_t> operands);
 
   LogicalResult processStructType(ArrayRef<uint32_t> operands);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index c21cb27b072f1..aaa80470f40e1 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -728,6 +728,11 @@ LogicalResult Serializer::prepareBasicType(
     return processTypeDecoration(loc, runtimeArrayType, resultID);
   }
 
+  if (isa<spirv::SamplerType>(type)) {
+    typeEnum = spirv::Opcode::OpTypeSampler;
+    return success();
+  }
+
   if (auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
     typeEnum = spirv::Opcode::OpTypeSampledImage;
     uint32_t imageTypeID = 0;
diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index 6e4126172f670..ef04b949c5219 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -100,7 +100,7 @@ func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2
 // -----
 
 func.func @composite_construct_vector_rank_two(%arg0: vector<2x2xi1>, %arg1: vector<2x2xi1>) -> vector<4x2xi1> {
-  // expected-error @+1 {{ op operand #0 must be variadic of void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V sampled image type or any SPIR-V image type or any SPIR-V tensorArm type, but got 'vector<2x2xi1>'}}
+  // expected-error @+1 {{ op operand #0 must be variadic of void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V sampled image type or any SPIR-V sampler type or any SPIR-V image type or any SPIR-V tensorArm type, but got 'vector<2x2xi1>'}}
   %0 = spirv.CompositeConstruct %arg0, %arg1 : (vector<2x2xi1>, vector<2x2xi1>) -> vector<4x2xi1>
   return %0: vector<4x2xi1>
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
index 12b5f2ce62a68..c4f90cfcb9a49 100644
--- a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
@@ -468,3 +468,47 @@ func.func @gard_too_many_args(%arg0 : !spirv.sampled_image<!spirv.image<f32, Dim
   %0 = spirv.ImageSampleExplicitLod %arg0, %arg1 ["Grad"], %arg2, %arg2, %arg2 : !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Rgba8>>, vector<2xf32>, vector<2xf32>, vector<2xf32>, vector<2xf32> -> vector<4xf32>
   spirv.Return
 }
+
+//===----------------------------------------------------------------------===//
+// spirv.SampledImage
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @sampled_image(%arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, %arg1 : !spirv.sampler) -> () {
+  // CHECK: spirv.SampledImage {{%.*}}, {{%.*}} : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, !spirv.sampler -> !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>
+  %0 = spirv.SampledImage %arg0, %arg1 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, !spirv.sampler -> !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>
+  spirv.Return
+}
+
+// -----
+
+func.func @sampled_image_sampler_unknown(%arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>, %arg1 : !spirv.sampler) -> () {
+  // CHECK: spirv.SampledImage {{%.*}}, {{%.*}} : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>, !spirv.sampler -> !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>>
+  %0 = spirv.SampledImage %arg0, %arg1 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>, !spirv.sampler -> !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>>
+  spirv.Return
+}
+
+// -----
+
+func.func @sampled_image_error(%arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, %arg1 : !spirv.sampler) -> () {
+  // expected-error @+1 {{type of 'result' wraps the image type of 'image'}}
+  %0 = spirv.SampledImage %arg0, %arg1 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, !spirv.sampler -> !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>
+  spirv.Return
+}
+
+// -----
+
+func.func @sampled_image_dim_subpassdata(%arg0 : !spirv.image<f32, SubpassData, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, %arg1 : !spirv.sampler) -> () {
+  // expected-error @+1 {{sampled image Dim must not be SubpassData or Buffer, got SubpassData}}
+  %0 = spirv.SampledImage %arg0, %arg1 : !spirv.image<f32, SubpassData, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, !spirv.sampler -> !spirv.sampled_image<!spirv.image<f32, SubpassData, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>
+  spirv.Return
+}
+
+// -----
+
+func.func @sampled_image_sampled_operand(%arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>, %arg1 : !spirv.sampler) -> () {
+  // expected-error @+1 {{the sampled operand of the underlying image must be SamplerUnknown or NeedSampler}}
+  %0 = spirv.SampledImage %arg0, %arg1 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>, !spirv.sampler -> !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>
+  spirv.Return
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 710673b73cee5..99443a13e0ec3 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -234,6 +234,15 @@ func.func private @image_parameters_nocomma_5(!spirv.image<f32, Dim1D, NoDepth,
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// SamplerType
+//===----------------------------------------------------------------------===//
+
+// CHECK: func private @sampler_type(!spirv.sampler)
+func.func private @sampler_type(!spirv.sampler) -> ()
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // SampledImageType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/image-ops.mlir b/mlir/test/Target/SPIRV/image-ops.mlir
index 3593d9b0e9b38..b664561ce9b01 100644
--- a/mlir/test/Target/SPIRV/image-ops.mlir
+++ b/mlir/test/Target/SPIRV/image-ops.mlir
@@ -61,3 +61,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, StorageImageWrit
     spirv.Return
   }
 }
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
+  spirv.func @sampled_image(%arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, %arg1 : !spirv.sampler) "None" {
+    // CHECK: {{%.*}} = spirv.SampledImage {{%.*}}, {{%.*}} : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, !spirv.sampler -> !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>
+    %0 = spirv.SampledImage %arg0, %arg1 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, !spirv.sampler -> !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>
+    spirv.Return
+  }
+}
diff --git a/mlir/test/Target/SPIRV/sampled-image.mlir b/mlir/test/Target/SPIRV/sampled-image.mlir
index ff068208540f4..4f6f3256acbac 100644
--- a/mlir/test/Target/SPIRV/sampled-image.mlir
+++ b/mlir/test/Target/SPIRV/sampled-image.mlir
@@ -14,4 +14,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Sampled1D, Sampl
 
   // CHECK: !spirv.ptr<!spirv.sampled_image<!spirv.image<i32, Rect, DepthUnknown, Arrayed, MultiSampled, NeedSampler, R8ui>>, UniformConstant>
   spirv.GlobalVariable @var2 bind(0, 0) : !spirv.ptr<!spirv.sampled_image<!spirv.image<i32, Rect, DepthUnknown, Arrayed, MultiSampled, NeedSampler, R8ui>>, UniformConstant>
+
+  // CHECK: !spirv.ptr<!spirv.sampler, UniformConstant>
+  spirv.GlobalVariable @var3 bind(0, 2) : !spirv.ptr<!spirv.sampler, UniformConstant>
 }



More information about the Mlir-commits mailing list