[Mlir-commits] [mlir] [mlir][spirv] Refactor image operations (PR #128552)

Igor Wodiany llvmlistbot at llvm.org
Tue Feb 25 03:06:29 PST 2025


https://github.com/IgWod-IMG updated https://github.com/llvm/llvm-project/pull/128552

>From a4a35bf23be1479de54a4cf2b7b9b82533027407 Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Mon, 10 Feb 2025 16:30:28 +0000
Subject: [PATCH] [mlir][spirv] Refactor image operations

This patch makes multiple changes to image ops:

  1) The assembly format is unified with the rest of the dialect
     to use `%0 = spirv.op %1, %2, %3 : f32, f32, f32` rather than
     having each type directly attached to each argument.
  2) The verification is moved from `SPIRVOps.cpp` to a new file,
     so the ops can be more easily maintained.
  3) The majority of the C++ verification is moved into ODS.
     Verification of `ImageQuerySizeOp` is left in C++ due to the
     complexity of its rules.
  4) `spirv::bitEnumContainsAll` is replaced by `spirv::bitEnumContainsAny`
     in `verifyImageOperands`. In this context `...Any` is the correct
     function, as we want to check whether any unsupported operand is
     being used, as opposed to checking whether all unsupported operands
     are being used.
  5) Target tests are simplified by removing entry points and adding
     the `Linkage` capability to the modules.
---
 .../mlir/Dialect/SPIRV/IR/SPIRVImageOps.td    | 108 +++++++----
 mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt      |   1 +
 mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp        | 137 ++++++++++++++
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        | 172 ------------------
 mlir/test/Dialect/SPIRV/IR/image-ops.mlir     |  44 ++---
 mlir/test/Target/SPIRV/image-ops.mlir         |  24 +--
 6 files changed, 245 insertions(+), 241 deletions(-)
 create mode 100644 mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td
index b7d6ec70ce141..a4fe29536e60a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 //
 // This file contains image ops for the SPIR-V dialect. It corresponds
-// to "3.37.10. Image Instructions" of the SPIR-V specification.
+// to "3.56.10. Image Instructions" of the SPIR-V specification.
 //
 //===----------------------------------------------------------------------===//
 
@@ -19,12 +19,56 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 
 // -----
 
-def SPIRV_ImageDrefGatherOp : SPIRV_Op<"ImageDrefGather", [Pure]> {
+// True if `getter()` of the image type derived from `operand` is one of `values`.
+// `transform` maps `$_self` (the operand's type) to a `spirv::ImageType`.
+class SPIRV_ValuesAreContained<string operand, list<string> values, string transform, string type, string getter> :
+  CPred<"::llvm::is_contained("
+    "{::mlir::spirv::" # type # "::" # !interleave(values, ", ::mlir::spirv::" # type # "::") # "},"
+    "::llvm::cast<::mlir::spirv::ImageType>(" # !subst("$_self", "$" # operand # ".getType()", transform) # ")." # getter # "()"
+  ")"
+>;
+
+class SPIRV_SampledOperandIs<string operand, list<string> values, string transform="$_self"> : PredOpTrait<
+  "the sampled operand of the underlying image must be " # !interleave(values, " or "),
+  SPIRV_ValuesAreContained<operand, values, transform, "ImageSamplerUseInfo", "getSamplerUseInfo">
+>;
+
+class SPIRV_MSOperandIs<string operand, list<string> values, string transform="$_self"> : PredOpTrait<
+  "the MS operand of the underlying image type must be " # !interleave(values, " or "),
+  SPIRV_ValuesAreContained<operand, values, transform, "ImageSamplingInfo", "getSamplingInfo">
+>;
+
+class SPIRV_DimIs<string operand, list<string> values, string transform="$_self"> : PredOpTrait<
+  "the Dim operand of the underlying image must be " # !interleave(values, " or "),
+  SPIRV_ValuesAreContained<operand, values, transform, "Dim", "getDim">
+>;
+
+class SPIRV_DimIsNot<string operand, list<string> values, string transform="$_self"> : PredOpTrait<
+  "the Dim operand of the underlying image must not be " # !interleave(values, " or "),
+  Neg<SPIRV_ValuesAreContained<operand, values, transform, "Dim", "getDim">>
+>;
+
+// True if the image's sampled type is NoneType or equals the element type of `operand`.
+class SPIRV_NoneOrElementMatchImage<string operand, string image, string transform="$_self"> : PredOpTrait<
+  "the " # operand # " component type must match the image sampled type",
+  CPred<"::llvm::isa<::mlir::NoneType>(::llvm::cast<::mlir::spirv::ImageType>(" # !subst("$_self", "$" # image # ".getType()", transform) # ").getElementType()) || "
+        "(::mlir::getElementTypeOrSelf($" # operand # ") == ::llvm::cast<::mlir::spirv::ImageType>(" # !subst("$_self", "$" # image # ".getType()", transform) # ").getElementType())">
+>;
+
+def SPIRV_SampledImageTransform : StrFunc<"::llvm::cast<::mlir::spirv::SampledImageType>($_self).getImageType()">;
+
+// -----
+
+def SPIRV_ImageDrefGatherOp : SPIRV_Op<"ImageDrefGather", 
+    [Pure,
+     SPIRV_DimIs<"sampled_image", ["Dim2D", "Cube", "Rect"], SPIRV_SampledImageTransform.result>,
+     SPIRV_MSOperandIs<"sampled_image", ["SingleSampled"], SPIRV_SampledImageTransform.result>,
+     SPIRV_NoneOrElementMatchImage<"result", "sampled_image", SPIRV_SampledImageTransform.result>]>{
   let summary = "Gathers the requested depth-comparison from four texels.";
 
   let description = [{
     Result Type must be a vector of four components of floating-point type
-    or integer type.  Its components must be the same as Sampled Type of the
+    or integer type. Its components must be the same as Sampled Type of the
     underlying OpTypeImage (unless that underlying Sampled Type is
     OpTypeVoid). It has one component per gathered texel.
 
@@ -32,8 +76,8 @@ def SPIRV_ImageDrefGatherOp : SPIRV_Op<"ImageDrefGather", [Pure]> {
     OpTypeImage must have a Dim of 2D, Cube, or Rect. The MS operand of the
     underlying OpTypeImage must be 0.
 
-    Coordinate  must be a scalar or vector of floating-point type.  It
-    contains (u[, v] … [, array layer]) as needed by the definition of
+    Coordinate must be a scalar or vector of floating-point type. It
+    contains (u[, v] ... [, array layer]) as needed by the definition of
     Sampled Image.
 
     Dref is the depth-comparison reference value. It must be a 32-bit
@@ -44,8 +88,8 @@ def SPIRV_ImageDrefGatherOp : SPIRV_Op<"ImageDrefGather", [Pure]> {
     #### Example:
 
     ```mlir
-    %0 = spirv.ImageDrefGather %1 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %2 : vector<4xf32>, %3 : f32 -> vector<4xi32>
-    %0 = spirv.ImageDrefGather %1 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %2 : vector<4xf32>, %3 : f32 ["NonPrivateTexel"] : f32, f32 -> vector<4xi32>
+    %0 = spirv.ImageDrefGather %1, %2, %3 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, vector<4xf32>, f32 -> vector<4xi32>
+    %0 = spirv.ImageDrefGather %1, %2, %3 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, vector<4xf32>, f32 ["NonPrivateTexel"] -> vector<4xi32>
     ```
   }];
 
@@ -57,23 +101,24 @@ def SPIRV_ImageDrefGatherOp : SPIRV_Op<"ImageDrefGather", [Pure]> {
   ];
 
   let arguments = (ins
-    SPIRV_AnySampledImage:$sampledimage,
+    SPIRV_AnySampledImage:$sampled_image,
     SPIRV_ScalarOrVectorOf<SPIRV_Float>:$coordinate,
-    SPIRV_Float:$dref,
-    OptionalAttr<SPIRV_ImageOperandsAttr>:$imageoperands,
+    SPIRV_Float32:$dref,
+    OptionalAttr<SPIRV_ImageOperandsAttr>:$image_operands,
     Variadic<SPIRV_Type>:$operand_arguments
   );
 
   let results = (outs
-    SPIRV_Vector:$result
+    AnyTypeOf<[SPIRV_Vec4<SPIRV_Integer>, SPIRV_Vec4<SPIRV_Float>]>:$result
   );
 
-  let assemblyFormat = [{$sampledimage `:` type($sampledimage) `,`
-                         $coordinate `:` type($coordinate) `,` $dref `:` type($dref)
-                         custom<ImageOperands>($imageoperands)
-                         ( `(` $operand_arguments^ `:` type($operand_arguments) `)`)?
-                         attr-dict
-                         `->` type($result)}];
+
+  let assemblyFormat = [{
+    $sampled_image `,` $coordinate `,` $dref custom<ImageOperands>($image_operands) ( `(` $operand_arguments^ `)` )? attr-dict 
+    `:` type($sampled_image) `,` type($coordinate) `,` type($dref) ( `(` type($operand_arguments)^ `)` )?
+    `->` type($result) 
+  }];
+
 }
 
 // -----
@@ -82,7 +127,7 @@ def SPIRV_ImageQuerySizeOp : SPIRV_Op<"ImageQuerySize", [Pure]> {
   let summary = "Query the dimensions of Image, with no level of detail.";
 
   let description = [{
-    Result Type must be an integer type scalar or vector.  The number of
+    Result Type must be an integer type scalar or vector. The number of
     components must be:
 
     1 for the 1D and Buffer dimensionalities,
@@ -130,12 +175,15 @@ def SPIRV_ImageQuerySizeOp : SPIRV_Op<"ImageQuerySize", [Pure]> {
     SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$result
   );
 
-  let assemblyFormat = "attr-dict $image `:` type($image) `->` type($result)";
+  let assemblyFormat = "$image attr-dict `:` type($image) `->` type($result)";
 }
 
 // -----
 
-def SPIRV_ImageWriteOp : SPIRV_Op<"ImageWrite", []> {
+def SPIRV_ImageWriteOp : SPIRV_Op<"ImageWrite",
+    [SPIRV_SampledOperandIs<"image", ["SamplerUnknown", "NoSampler"]>,
+     SPIRV_DimIsNot<"image", ["SubpassData"]>,
+     SPIRV_NoneOrElementMatchImage<"texel", "image">]> {
   let summary = "Write a texel to an image without a sampler.";
 
   let description = [{
@@ -163,7 +211,7 @@ def SPIRV_ImageWriteOp : SPIRV_Op<"ImageWrite", []> {
     #### Example:
 
     ```mlir
-    spirv.ImageWrite %0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba16>, %1 : vector<2xsi32>, %2 : vector<4xf32>
+    spirv.ImageWrite %0, %1, %2 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba16>, vector<2xsi32>, vector<4xf32>
     ```
   }];
 
@@ -177,20 +225,18 @@ def SPIRV_ImageWriteOp : SPIRV_Op<"ImageWrite", []> {
 
   let results = (outs);
 
-  let assemblyFormat = [{$image `:` type($image) `,`
-                         $coordinate `:` type($coordinate) `,`
-                         $texel `:` type($texel)
-                         custom<ImageOperands>($image_operands)
-                         ( `(` $operand_arguments^ `:` type($operand_arguments) `)`)?
-                         attr-dict}];
+  let assemblyFormat = [{
+    $image `,` $coordinate `,` $texel custom<ImageOperands>($image_operands) ( `(` $operand_arguments^ `)`)? attr-dict
+    `:` type($image) `,` type($coordinate) `,` type($texel) ( `(` type($operand_arguments)^ `)`)?
+  }];
 }
 
 // -----
 
 def SPIRV_ImageOp : SPIRV_Op<"Image",
     [Pure,
-     TypesMatchWith<"type of 'result' matches image type of 'sampledimage'",
-                    "sampledimage", "result",
+     TypesMatchWith<"type of 'result' matches image type of 'sampled_image'",
+                    "sampled_image", "result",
                     "::llvm::cast<spirv::SampledImageType>($_self).getImageType()">]> {
   let summary = "Extract the image from a sampled image.";
 
@@ -210,14 +256,14 @@ def SPIRV_ImageOp : SPIRV_Op<"Image",
   }];
 
   let arguments = (ins
-    SPIRV_AnySampledImage:$sampledimage
+    SPIRV_AnySampledImage:$sampled_image
   );
 
   let results = (outs
     SPIRV_AnyImage:$result
   );
 
-  let assemblyFormat = "attr-dict $sampledimage `:` type($sampledimage)";
+  let assemblyFormat = "$sampled_image attr-dict `:` type($sampled_image)";
 
   let hasVerifier = 0;
 }
diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index ae8ad5a491ff2..235beb0b6a097 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRSPIRVDialect
   ControlFlowOps.cpp
   CooperativeMatrixOps.cpp
   GroupOps.cpp
+  ImageOps.cpp
   IntegerDotProductOps.cpp
   MemoryOps.cpp
   MeshOps.cpp
diff --git a/mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp
new file mode 100644
index 0000000000000..d7fcd40c43d88
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp
@@ -0,0 +1,137 @@
+//===- ImageOps.cpp - MLIR SPIR-V Image Ops  ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the image operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Common utility functions
+//===----------------------------------------------------------------------===//
+
+/// Verifies the Image Operands `attr` and the `operands` that follow it.
+static LogicalResult verifyImageOperands(Operation *imageOp,
+                                         spirv::ImageOperandsAttr attr,
+                                         Operation::operand_range operands) {
+  if (!attr) {
+    if (operands.empty())
+      return success();
+    return imageOp->emitError("the Image Operands should encode what operands "
+                              "follow, as per Image Operands");
+  }
+
+  // TODO: Add the validation rules for the following Image Operands.
+  spirv::ImageOperands noSupportOperands =
+      spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
+      spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
+      spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
+      spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
+      spirv::ImageOperands::MakeTexelAvailable |
+      spirv::ImageOperands::MakeTexelVisible |
+      spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
+  // A verifier must diagnose bad IR rather than crash, and an assert would
+  // leave `noSupportOperands` unused (a warning) in NDEBUG builds.
+  if (spirv::bitEnumContainsAny(attr.getValue(), noSupportOperands))
+    return imageOp->emitError("unimplemented operands of Image Operands");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ImageDrefGather
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::ImageDrefGatherOp::verify() {
+  Operation *op = getOperation();
+  return verifyImageOperands(op, getImageOperandsAttr(), getOperandArguments());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ImageWriteOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::ImageWriteOp::verify() {
+  // Sampler, Dim and texel-type rules are enforced by the ODS traits, so only
+  // the Image Operands are left to check here.
+  // TODO: Do we need a check for: "If the Arrayed operand is 1, then additional
+  // capabilities may be required; e.g., ImageCubeArray, or ImageMSArray."?
+  // TODO: Ideally it should be somewhere verified that "The Image Format must
+  // not be Unknown, unless the StorageImageWriteWithoutFormat Capability was
+  // declared." This function may not be the suitable place for it.
+  Operation *op = getOperation();
+  spirv::ImageOperandsAttr operandsAttr = getImageOperandsAttr();
+  return verifyImageOperands(op, operandsAttr, getOperandArguments());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ImageQuerySize
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::ImageQuerySizeOp::verify() {
+  auto imageType = llvm::cast<spirv::ImageType>(getImage().getType());
+  spirv::Dim dim = imageType.getDim();
+
+  // 1D, 2D, 3D and Cube images must be multisampled, or have a Sampled
+  // operand of 0 (SamplerUnknown) or 2 (NoSampler).
+  auto isSamplingCompatible = [&]() {
+    spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
+    return imageType.getSamplingInfo() ==
+               spirv::ImageSamplingInfo::MultiSampled ||
+           samplerInfo == spirv::ImageSamplerUseInfo::SamplerUnknown ||
+           samplerInfo == spirv::ImageSamplerUseInfo::NoSampler;
+  };
+
+  // Number of size components implied by the dimensionality; unsupported
+  // dimensionalities are rejected up front.
+  unsigned componentNumber = 0;
+  switch (dim) {
+  case spirv::Dim::Buffer:
+    componentNumber = 1;
+    break;
+  case spirv::Dim::Rect:
+    componentNumber = 2;
+    break;
+  case spirv::Dim::Dim1D:
+  case spirv::Dim::Dim2D:
+  case spirv::Dim::Dim3D:
+  case spirv::Dim::Cube:
+    if (!isSamplingCompatible())
+      return emitError(
+          "if Dim is 1D, 2D, 3D, or Cube, "
+          "it must also have either an MS of 1 or a Sampled of 0 or 2");
+    if (dim == spirv::Dim::Dim1D)
+      componentNumber = 1;
+    else if (dim == spirv::Dim::Dim3D)
+      componentNumber = 3;
+    else
+      componentNumber = 2;
+    break;
+  default:
+    return emitError("the Dim operand of the image type must "
+                     "be 1D, 2D, 3D, Buffer, Cube, or Rect");
+  }
+
+  // Arrayed images require one additional component.
+  if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
+    ++componentNumber;
+
+  // A scalar result counts as a single component.
+  unsigned resultComponentNumber = 1;
+  if (auto vectorType = llvm::dyn_cast<VectorType>(getResult().getType()))
+    resultComponentNumber = vectorType.getNumElements();
+
+  if (componentNumber == resultComponentNumber)
+    return success();
+
+  return emitError("expected the result to have ")
+         << componentNumber << " component(s), but found "
+         << resultComponentNumber << " component(s)";
+}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index dc414339ae7b8..da9855b02860d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -166,34 +166,6 @@ static void printOneResultOp(Operation *op, OpAsmPrinter &p) {
   p << " : " << resultType;
 }
 
-template <typename Op>
-static LogicalResult verifyImageOperands(Op imageOp,
-                                         spirv::ImageOperandsAttr attr,
-                                         Operation::operand_range operands) {
-  if (!attr) {
-    if (operands.empty())
-      return success();
-
-    return imageOp.emitError("the Image Operands should encode what operands "
-                             "follow, as per Image Operands");
-  }
-
-  // TODO: Add the validation rules for the following Image Operands.
-  spirv::ImageOperands noSupportOperands =
-      spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
-      spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
-      spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
-      spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
-      spirv::ImageOperands::MakeTexelAvailable |
-      spirv::ImageOperands::MakeTexelVisible |
-      spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
-
-  if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands))
-    llvm_unreachable("unimplemented operands of Image Operands");
-
-  return success();
-}
-
 template <typename BlockReadWriteOpTy>
 static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
                                                         Value ptr, Value val) {
@@ -2002,85 +1974,6 @@ LogicalResult spirv::GLLdexpOp::verify() {
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// spirv.ImageDrefGather
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::ImageDrefGatherOp::verify() {
-  VectorType resultType = llvm::cast<VectorType>(getResult().getType());
-  auto sampledImageType =
-      llvm::cast<spirv::SampledImageType>(getSampledimage().getType());
-  auto imageType =
-      llvm::cast<spirv::ImageType>(sampledImageType.getImageType());
-
-  if (resultType.getNumElements() != 4)
-    return emitOpError("result type must be a vector of four components");
-
-  Type elementType = resultType.getElementType();
-  Type sampledElementType = imageType.getElementType();
-  if (!llvm::isa<NoneType>(sampledElementType) &&
-      elementType != sampledElementType)
-    return emitOpError(
-        "the component type of result must be the same as sampled type of the "
-        "underlying image type");
-
-  spirv::Dim imageDim = imageType.getDim();
-  spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo();
-
-  if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube &&
-      imageDim != spirv::Dim::Rect)
-    return emitOpError(
-        "the Dim operand of the underlying image type must be 2D, Cube, or "
-        "Rect");
-
-  if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
-    return emitOpError("the MS operand of the underlying image type must be 0");
-
-  spirv::ImageOperandsAttr attr = getImageoperandsAttr();
-  auto operandArguments = getOperandArguments();
-
-  return verifyImageOperands(*this, attr, operandArguments);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.ImageWriteOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::ImageWriteOp::verify() {
-  ImageType imageType = cast<ImageType>(getImage().getType());
-  Type sampledType = imageType.getElementType();
-  ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
-
-  if (!llvm::is_contained({spirv::ImageSamplerUseInfo::SamplerUnknown,
-                           spirv::ImageSamplerUseInfo::NoSampler},
-                          samplerInfo)) {
-    return emitOpError(
-        "the sampled operand of the underlying image must be 0 or 2");
-  }
-
-  // TODO: Do we need check for: "If the Arrayed operand is 1, then additional
-  // capabilities may be required; e.g., ImageCubeArray, or ImageMSArray."?
-
-  if (imageType.getDim() == spirv::Dim::SubpassData) {
-    return emitOpError(
-        "the Dim operand of the underlying image must not be SubpassData");
-  }
-
-  Type texelType = getElementTypeOrSelf(getTexel());
-  if (!isa<NoneType>(sampledType) && texelType != sampledType) {
-    return emitOpError(
-        "the texel component type must match the image sampled type");
-  }
-
-  // TODO: Ideally it should be somewhere verified that "The Image Format must
-  // not be Unknown, unless the StorageImageWriteWithoutFormat Capability was
-  // declared." This function however may not be the suitable place for such
-  // verification.
-
-  return verifyImageOperands(*this, getImageOperandsAttr(),
-                             getOperandArguments());
-}
-
 //===----------------------------------------------------------------------===//
 // spirv.ShiftLeftLogicalOp
 //===----------------------------------------------------------------------===//
@@ -2105,71 +1998,6 @@ LogicalResult spirv::ShiftRightLogicalOp::verify() {
   return verifyShiftOp(*this);
 }
 
-//===----------------------------------------------------------------------===//
-// spirv.ImageQuerySize
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::ImageQuerySizeOp::verify() {
-  spirv::ImageType imageType =
-      llvm::cast<spirv::ImageType>(getImage().getType());
-  Type resultType = getResult().getType();
-
-  spirv::Dim dim = imageType.getDim();
-  spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
-  spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
-  switch (dim) {
-  case spirv::Dim::Dim1D:
-  case spirv::Dim::Dim2D:
-  case spirv::Dim::Dim3D:
-  case spirv::Dim::Cube:
-    if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled &&
-        samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown &&
-        samplerInfo != spirv::ImageSamplerUseInfo::NoSampler)
-      return emitError(
-          "if Dim is 1D, 2D, 3D, or Cube, "
-          "it must also have either an MS of 1 or a Sampled of 0 or 2");
-    break;
-  case spirv::Dim::Buffer:
-  case spirv::Dim::Rect:
-    break;
-  default:
-    return emitError("the Dim operand of the image type must "
-                     "be 1D, 2D, 3D, Buffer, Cube, or Rect");
-  }
-
-  unsigned componentNumber = 0;
-  switch (dim) {
-  case spirv::Dim::Dim1D:
-  case spirv::Dim::Buffer:
-    componentNumber = 1;
-    break;
-  case spirv::Dim::Dim2D:
-  case spirv::Dim::Cube:
-  case spirv::Dim::Rect:
-    componentNumber = 2;
-    break;
-  case spirv::Dim::Dim3D:
-    componentNumber = 3;
-    break;
-  default:
-    break;
-  }
-
-  if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
-    componentNumber += 1;
-
-  unsigned resultComponentNumber = 1;
-  if (auto resultVectorType = llvm::dyn_cast<VectorType>(resultType))
-    resultComponentNumber = resultVectorType.getNumElements();
-
-  if (componentNumber != resultComponentNumber)
-    return emitError("expected the result to have ")
-           << componentNumber << " component(s), but found "
-           << resultComponentNumber << " component(s)";
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // spirv.VectorTimesScalarOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
index 1161f85563ae6..266b69fd117d8 100644
--- a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
@@ -1,20 +1,20 @@
-// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // spirv.ImageDrefGather
 //===----------------------------------------------------------------------===//
 
 func.func @image_dref_gather(%arg0 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32) -> () {
-  // CHECK: spirv.ImageDrefGather {{.*}} : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, {{.*}} : vector<4xf32>, {{.*}} : f32 -> vector<4xi32>
-  %0 = spirv.ImageDrefGather %arg0 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32 -> vector<4xi32>
+  // CHECK: spirv.ImageDrefGather {{.*}}, {{.*}}, {{.*}} : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, vector<4xf32>, f32 -> vector<4xi32>
+  %0 = spirv.ImageDrefGather %arg0, %arg1, %arg2 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, vector<4xf32>, f32 -> vector<4xi32>
   spirv.Return
 }
 
 // -----
 
 func.func @image_dref_gather_with_single_imageoperands(%arg0 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32) -> () {
-  // CHECK: spirv.ImageDrefGather {{.*}} ["NonPrivateTexel"] -> vector<4xi32>
-  %0 = spirv.ImageDrefGather %arg0 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32 ["NonPrivateTexel"] -> vector<4xi32>
+  // CHECK: spirv.ImageDrefGather {{.*}}, {{.*}}, {{.*}} ["NonPrivateTexel"] : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, vector<4xf32>, f32 -> vector<4xi32>
+  %0 = spirv.ImageDrefGather %arg0, %arg1, %arg2 ["NonPrivateTexel"] : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, vector<4xf32>, f32 -> vector<4xi32>
   spirv.Return
 }
 
@@ -22,39 +22,39 @@ func.func @image_dref_gather_with_single_imageoperands(%arg0 : !spirv.sampled_im
 
 func.func @image_dref_gather_with_mismatch_imageoperands(%arg0 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32) -> () {
   // expected-error @+1 {{the Image Operands should encode what operands follow, as per Image Operands}}
-  %0 = spirv.ImageDrefGather %arg0 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32 (%arg2, %arg2 : f32, f32) -> vector<4xi32>
+  %0 = spirv.ImageDrefGather %arg0, %arg1, %arg2 (%arg2, %arg2) : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, vector<4xf32>, f32 (f32, f32) -> vector<4xi32>
   spirv.Return
 }
 
 // -----
 
 func.func @image_dref_gather_error_result_type(%arg0 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32) -> () {
-  // expected-error @+1 {{result type must be a vector of four components}}
-  %0 = spirv.ImageDrefGather %arg0 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32 -> vector<3xi32>
+  // expected-error @+1 {{must be vector of 8/16/32/64-bit integer values of length 4 or vector of 16/32/64-bit float values of length 4}}
+  %0 = spirv.ImageDrefGather %arg0, %arg1, %arg2 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, vector<4xf32>, f32 -> vector<3xi32>
   spirv.Return
 }
 
 // -----
 
 func.func @image_dref_gather_error_same_type(%arg0 : !spirv.sampled_image<!spirv.image<i32, Rect, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32) -> () {
-  // expected-error @+1 {{the component type of result must be the same as sampled type of the underlying image type}}
-  %0 = spirv.ImageDrefGather %arg0 : !spirv.sampled_image<!spirv.image<i32, Rect, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32 -> vector<4xf32>
+  // expected-error @+1 {{the result component type must match the image sampled type}}
+  %0 = spirv.ImageDrefGather %arg0, %arg1, %arg2 : !spirv.sampled_image<!spirv.image<i32, Rect, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, vector<4xf32>, f32 -> vector<4xf32>
   spirv.Return
 }
 
 // -----
 
 func.func @image_dref_gather_error_dim(%arg0 : !spirv.sampled_image<!spirv.image<i32, Dim1D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32) -> () {
-  // expected-error @+1 {{the Dim operand of the underlying image type must be 2D, Cube, or Rect}}
-  %0 = spirv.ImageDrefGather %arg0 : !spirv.sampled_image<!spirv.image<i32, Dim1D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32 -> vector<4xi32>
+  // expected-error @+1 {{the Dim operand of the underlying image must be Dim2D or Cube or Rect}}
+  %0 = spirv.ImageDrefGather %arg0, %arg1, %arg2 : !spirv.sampled_image<!spirv.image<i32, Dim1D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, vector<4xf32>, f32 -> vector<4xi32>
   spirv.Return
 }
 
 // -----
 
 func.func @image_dref_gather_error_ms(%arg0 : !spirv.sampled_image<!spirv.image<i32, Cube, NoDepth, NonArrayed, MultiSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32) -> () {
-  // expected-error @+1 {{the MS operand of the underlying image type must be 0}}
-  %0 = spirv.ImageDrefGather %arg0 : !spirv.sampled_image<!spirv.image<i32, Cube, NoDepth, NonArrayed, MultiSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32 -> vector<4xi32>
+  // expected-error @+1 {{the MS operand of the underlying image type must be SingleSampled}}
+  %0 = spirv.ImageDrefGather %arg0, %arg1, %arg2 : !spirv.sampled_image<!spirv.image<i32, Cube, NoDepth, NonArrayed, MultiSampled, NoSampler, Unknown>>, vector<4xf32>, f32 -> vector<4xi32>
   spirv.Return
 }
 
@@ -121,24 +121,24 @@ func.func @image_query_size_error_result2(%arg0 : !spirv.image<f32, Buffer, NoDe
 //===----------------------------------------------------------------------===//
 
 func.func @image_write(%arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba16>, %arg1 : vector<2xsi32>, %arg2 : vector<4xf32>) -> () {
-  // CHECK:  spirv.ImageWrite %arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba16>, %arg1 : vector<2xsi32>, %arg2 : vector<4xf32>
-  spirv.ImageWrite %arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba16>, %arg1 : vector<2xsi32>, %arg2 : vector<4xf32>
+  // CHECK:  spirv.ImageWrite {{%.*}}, {{%.*}}, {{%.*}} : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba16>, vector<2xsi32>, vector<4xf32>
+  spirv.ImageWrite %arg0, %arg1, %arg2 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba16>, vector<2xsi32>, vector<4xf32>
   spirv.Return
 }
 
 // -----
 
 func.func @image_write_scalar_texel(%arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba16>, %arg1 : vector<2xsi32>, %arg2 : f32) -> () {
-  // CHECK:  spirv.ImageWrite %arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba16>, %arg1 : vector<2xsi32>, %arg2 : f32
-  spirv.ImageWrite %arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba16>, %arg1 : vector<2xsi32>, %arg2 : f32
+  // CHECK:  spirv.ImageWrite {{%.*}}, {{%.*}}, {{%.*}} : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba16>, vector<2xsi32>, f32
+  spirv.ImageWrite %arg0, %arg1, %arg2 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba16>, vector<2xsi32>, f32
   spirv.Return
 }
 
 // -----
 
 func.func @image_write_need_sampler(%arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Rgba16>, %arg1 : vector<2xsi32>, %arg2 : vector<4xf32>) -> () {
-  // expected-error @+1 {{the sampled operand of the underlying image must be 0 or 2}}
-  spirv.ImageWrite %arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Rgba16>, %arg1 : vector<2xsi32>, %arg2 : vector<4xf32>
+  // expected-error @+1 {{the sampled operand of the underlying image must be SamplerUnknown or NoSampler}}
+  spirv.ImageWrite %arg0, %arg1, %arg2 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Rgba16>, vector<2xsi32>, vector<4xf32>
   spirv.Return
 }
 
@@ -146,7 +146,7 @@ func.func @image_write_need_sampler(%arg0 : !spirv.image<f32, Dim2D, NoDepth, No
 
 func.func @image_write_subpass_data(%arg0 : !spirv.image<f32, SubpassData, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba16>, %arg1 : vector<2xsi32>, %arg2 : vector<4xf32>) -> () {
   // expected-error @+1 {{the Dim operand of the underlying image must not be SubpassData}}
-  spirv.ImageWrite %arg0 : !spirv.image<f32, SubpassData, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba16>, %arg1 : vector<2xsi32>, %arg2 : vector<4xf32>
+  spirv.ImageWrite %arg0, %arg1, %arg2 : !spirv.image<f32, SubpassData, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba16>, vector<2xsi32>, vector<4xf32>
   spirv.Return
 }
 
@@ -154,6 +154,6 @@ func.func @image_write_subpass_data(%arg0 : !spirv.image<f32, SubpassData, NoDep
 
 func.func @image_write_texel_type_mismatch(%arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba16>, %arg1 : vector<2xsi32>, %arg2 : vector<4xi32>) -> () {
   // expected-error @+1 {{the texel component type must match the image sampled type}}
-  spirv.ImageWrite %arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba16>, %arg1 : vector<2xsi32>, %arg2 : vector<4xi32>
+  spirv.ImageWrite %arg0, %arg1, %arg2 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba16>, vector<2xsi32>, vector<4xi32>
   spirv.Return
 }
diff --git a/mlir/test/Target/SPIRV/image-ops.mlir b/mlir/test/Target/SPIRV/image-ops.mlir
index 6b52a84ba82f7..3c28c3f71bc2a 100644
--- a/mlir/test/Target/SPIRV/image-ops.mlir
+++ b/mlir/test/Target/SPIRV/image-ops.mlir
@@ -1,11 +1,11 @@
 // RUN: mlir-translate --no-implicit-module --split-input-file --test-spirv-roundtrip %s | FileCheck %s
 
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ImageQuery], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ImageQuery, Linkage], []> {
   spirv.func @image(%arg0 : !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32) "None" {
     // CHECK: {{%.*}} = spirv.Image {{%.*}} : !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>
     %0 = spirv.Image %arg0 : !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>
-    // CHECK: {{%.*}} = spirv.ImageDrefGather {{%.*}} : !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>,  {{%.*}} : vector<4xf32>,  {{%.*}} : f32 -> vector<4xf32>
-    %1 = spirv.ImageDrefGather %arg0 : !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32 -> vector<4xf32>
+    // CHECK: {{%.*}} = spirv.ImageDrefGather {{%.*}}, {{%.*}}, {{%.*}} : !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>,  vector<4xf32>, f32 -> vector<4xf32>
+    %1 = spirv.ImageDrefGather %arg0, %arg1, %arg2 : !spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>, vector<4xf32>, f32 -> vector<4xf32>
     spirv.Return
   }
   spirv.func @image_query_size(%arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>) "None" {
@@ -14,26 +14,18 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ImageQuery], []>
     spirv.Return
   }
   spirv.func @image_write(%arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba8>, %arg1 : vector<2xsi32>, %arg2 : vector<4xf32>) "None" {
-    // CHECK:  spirv.ImageWrite %arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba8>, %arg1 : vector<2xsi32>, %arg2 : vector<4xf32>
-    spirv.ImageWrite %arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba8>, %arg1 : vector<2xsi32>, %arg2 : vector<4xf32>
+    // CHECK: spirv.ImageWrite {{%.*}}, {{%.*}}, {{%.*}} : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba8>, vector<2xsi32>, vector<4xf32>
+    spirv.ImageWrite %arg0, %arg1, %arg2 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba8>, vector<2xsi32>, vector<4xf32>
     spirv.Return
   }
-  spirv.func @main() "None" {
-    spirv.Return
-  }
-  spirv.EntryPoint "GLCompute" @main
 }
 
 // -----
 
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, StorageImageWriteWithoutFormat], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, StorageImageWriteWithoutFormat, Linkage], []> {
   spirv.func @image_write(%arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>, %arg1 : vector<2xsi32>, %arg2 : vector<4xf32>) "None" {
-    // CHECK:  spirv.ImageWrite %arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>, %arg1 : vector<2xsi32>, %arg2 : vector<4xf32>
-    spirv.ImageWrite %arg0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>, %arg1 : vector<2xsi32>, %arg2 : vector<4xf32>
-    spirv.Return
-  }
-  spirv.func @main() "None" {
+    // CHECK: spirv.ImageWrite {{%.*}}, {{%.*}}, {{%.*}} : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>, vector<2xsi32>, vector<4xf32>
+    spirv.ImageWrite %arg0, %arg1, %arg2 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>, vector<2xsi32>, vector<4xf32>
     spirv.Return
   }
-  spirv.EntryPoint "GLCompute" @main
 }



More information about the Mlir-commits mailing list