[Mlir-commits] [mlir] [mlir][gpu] Add gpu.rotate operation (PR #142796)

Hsiangkai Wang llvmlistbot at llvm.org
Fri Jun 27 03:51:08 PDT 2025


https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/142796

>From 855a0f69f9e6f777cb4a049f4d8bd7a94b67fcdf Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 4 Jun 2025 15:31:38 +0100
Subject: [PATCH 1/6] [mlir][gpu] Add gpu.rotate operation

Add gpu.rotate operation and a pattern to convert gpu.rotate to SPIR-V
OpGroupNonUniformRotateKHR.
---
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    | 43 ++++++++++
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 37 ++++++++-
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        | 45 +++++++++++
 mlir/test/Conversion/GPUToSPIRV/rotate.mlir   | 26 +++++++
 mlir/test/Dialect/GPU/invalid.mlir            | 78 +++++++++++++++++++
 mlir/test/Dialect/GPU/ops.mlir                |  4 +
 6 files changed, 232 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Conversion/GPUToSPIRV/rotate.mlir

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 15b14c767b66a..46bd6039657bd 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1364,6 +1364,49 @@ def GPU_ShuffleOp : GPU_Op<
   ];
 }
 
+def GPU_RotateOp : GPU_Op<
+    "rotate", [Pure, AllTypesMatch<["value", "rotateResult"]>]>,
+    Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, I32:$offset, I32:$width)>,
+    Results<(outs AnyIntegerOrFloatOr1DVector:$rotateResult)> {
+  let summary = "Rotate values within a subgroup.";
+  let description = [{
+    The "rotate" op moves values across lanes (a.k.a., invocations, work items)
+    within the same subgroup. The `width` argument specifies the number of lanes
+    that participate in the rotation, and must be uniform across all lanes.
+    Further, the first `width` lanes of the subgroup must be active.
+
+    `width` must be a power of two, and `offset` must be in the range
+    `[0, width)`.
+
+    Return the `rotateResult` of the invocation whose id within the group is
+    calculated as follows:
+
+    Invocation ID = ((LocalId + Delta) & (width - 1)) + (LocalId & ~(width - 1))
+
+    Returns the `rotateResult` if the current lane id is smaller than `width`.
+
+    example:
+
+    ```mlir
+    %cst1 = arith.constant 1 : i32
+    %1 = gpu.rotate %0, %cst1, %width : f32
+    ```
+
+    For lane `k`, returns the value from lane `(k + cst1) % width`.
+  }];
+
+  let assemblyFormat = [{
+    $value `,` $offset `,` $width attr-dict `:` type($value)
+  }];
+
+  let builders = [
+    // Helper function that creates a rotate with constant offset/width.
+    OpBuilder<(ins "Value":$value, "int32_t":$offset, "int32_t":$width)>
+  ];
+
+  let hasVerifier = 1;
+}
+
 def GPU_BarrierOp : GPU_Op<"barrier"> {
   let summary = "Synchronizes all work items of a workgroup.";
   let description = [{
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 78e6ebb523a46..546705244f35c 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -122,6 +122,16 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Pattern to convert a gpu.rotate op into a spirv.GroupNonUniformRotateKHROp.
+class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::RotateOp rotateOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -458,6 +468,31 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Rotate
+//===----------------------------------------------------------------------===//
+
+LogicalResult GPURotateConversion::matchAndRewrite(
+    gpu::RotateOp rotateOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
+  unsigned subgroupSize =
+      targetEnv.getAttr().getResourceLimits().getSubgroupSize();
+  IntegerAttr widthAttr;
+  if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) ||
+      widthAttr.getValue().getZExtValue() > subgroupSize)
+    return rewriter.notifyMatchFailure(
+        rotateOp, "rotate width is larger than target subgroup size");
+
+  Location loc = rotateOp.getLoc();
+  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
+  Value result = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
+      loc, scope, adaptor.getValue(), adaptor.getOffset(), adaptor.getWidth());
+
+  rewriter.replaceOp(rotateOp, result);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Group ops
 //===----------------------------------------------------------------------===//
@@ -733,7 +768,7 @@ void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                       RewritePatternSet &patterns) {
   patterns.add<
       GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
-      GPUReturnOpConversion, GPUShuffleConversion,
+      GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
       LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
       LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
       LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 39f626b558294..a9a9473a1c333 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1331,6 +1331,51 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
         mode);
 }
 
+//===----------------------------------------------------------------------===//
+// RotateOp
+//===----------------------------------------------------------------------===//
+
+void RotateOp::build(OpBuilder &builder, OperationState &result, Value value,
+                     int32_t offset, int32_t width) {
+  build(builder, result, value,
+        builder.create<arith::ConstantOp>(result.location,
+                                          builder.getI32IntegerAttr(offset)),
+        builder.create<arith::ConstantOp>(result.location,
+                                          builder.getI32IntegerAttr(width)));
+}
+
+LogicalResult RotateOp::verify() {
+  llvm::APInt offsetValue;
+  if (auto constOp = getOffset().getDefiningOp<arith::ConstantOp>()) {
+    if (auto intAttr = llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue())) {
+      offsetValue = intAttr.getValue();
+    } else {
+      return emitOpError() << "offset is not an integer value";
+    }
+  } else {
+    return emitOpError() << "offset is not a constant value";
+  }
+
+  llvm::APInt widthValue;
+  if (auto constOp = getWidth().getDefiningOp<arith::ConstantOp>()) {
+    if (auto intAttr = llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue())) {
+      widthValue = intAttr.getValue();
+    } else {
+      return emitOpError() << "width is not an integer value";
+    }
+  } else {
+    return emitOpError() << "width is not a constant value";
+  }
+
+  if (!widthValue.isPowerOf2())
+    return emitOpError() << "width must be a power of two";
+
+  if (offsetValue.sge(widthValue) || offsetValue.slt(0))
+    return emitOpError() << "offset must be in the range [0, width)";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // BarrierOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
new file mode 100644
index 0000000000000..e0dd14d87d42f
--- /dev/null
+++ b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
+    #spirv.resource_limits<subgroup_size = 16>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @rotate()
+  gpu.func @rotate() kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    %offset = arith.constant 4 : i32
+    %width = arith.constant 16 : i32
+    %val = arith.constant 42.0 : f32
+
+    // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
+    // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
+    // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
+    // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
+    %result = gpu.rotate %val, %offset, %width : f32
+    gpu.return
+  }
+}
+
+}
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index ce1be7b5618fe..0ad5690f5cf70 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -478,6 +478,84 @@ func.func @shuffle_unsupported_type_vec(%arg0 : vector<[4]xf32>, %arg1 : i32, %a
 
 // -----
 
+func.func @rotate_mismatching_type(%arg0 : f32) {
+  %offset = arith.constant 4 : i32
+  %width = arith.constant 16 : i32
+  // expected-error at +1 {{op failed to verify that all of {value, rotateResult} have same type}}
+  %shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> i32
+  return
+}
+
+// -----
+
+func.func @rotate_unsupported_type(%arg0 : index) {
+  %offset = arith.constant 4 : i32
+  %width = arith.constant 16 : i32
+  // expected-error at +1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'index'}}
+  %shfl = gpu.rotate %arg0, %offset, %width : index
+  return
+}
+
+// -----
+
+func.func @rotate_unsupported_type_vec(%arg0 : vector<[4]xf32>) {
+  %offset = arith.constant 4 : i32
+  %width = arith.constant 16 : i32
+  // expected-error at +1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'vector<[4]xf32>'}}
+  %shfl = gpu.rotate %arg0, %offset, %width : vector<[4]xf32>
+  return
+}
+
+// -----
+
+func.func @rotate_unsupported_width(%arg0 : f32) {
+  %offset = arith.constant 4 : i32
+  %width = arith.constant 15 : i32
+  // expected-error at +1 {{op width must be a power of two}}
+  %shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> f32
+  return
+}
+
+// -----
+
+func.func @rotate_unsupported_offset(%arg0 : f32) {
+  %offset = arith.constant 16 : i32
+  %width = arith.constant 16 : i32
+  // expected-error at +1 {{op offset must be in the range [0, width)}}
+  %shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> f32
+  return
+}
+
+// -----
+
+func.func @rotate_unsupported_offset_minus(%arg0 : f32) {
+  %offset = arith.constant -1 : i32
+  %width = arith.constant 16 : i32
+  // expected-error at +1 {{op offset must be in the range [0, width)}}
+  %shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> f32
+  return
+}
+
+// -----
+
+func.func @rotate_offset_non_constant(%arg0 : f32, %offset : i32) {
+  %width = arith.constant 16 : i32
+  // expected-error at +1 {{op offset is not a constant value}}
+  %shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> f32
+  return
+}
+
+// -----
+
+func.func @rotate_width_non_constant(%arg0 : f32, %width : i32) {
+  %offset = arith.constant 0 : i32
+  // expected-error at +1 {{op width is not a constant value}}
+  %shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> f32
+  return
+}
+
+// -----
+
 module {
   gpu.module @gpu_funcs {
     // expected-error @+1 {{custom op 'gpu.func' gpu.func requires named arguments}}
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 9dbe16774f517..4beb8ffa09ac6 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -140,6 +140,10 @@ module attributes {gpu.container_module} {
       // CHECK: gpu.shuffle idx %{{.*}}, %{{.*}}, %{{.*}} : f32
       %shfl3, %pred3 = gpu.shuffle idx %arg0, %offset, %width : f32
 
+      // CHECK: gpu.rotate %{{.*}}, %{{.*}}, %{{.*}} : f32
+      %rotate_width = arith.constant 16 : i32
+      %rotate = gpu.rotate %arg0, %offset, %rotate_width : f32
+
       "gpu.barrier"() : () -> ()
 
       "some_op"(%bIdX, %tIdX) : (index, index) -> ()

>From 6b80ff7347e94ba8d8c2aae55c5ae62a49157155 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Tue, 17 Jun 2025 11:12:10 +0100
Subject: [PATCH 2/6] Address comments

---
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    | 12 ++++---
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp |  6 +++-
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        | 35 +++++++++----------
 mlir/test/Conversion/GPUToSPIRV/rotate.mlir   |  5 ++-
 mlir/test/Dialect/GPU/invalid.mlir            | 16 ++++-----
 mlir/test/Dialect/GPU/ops.mlir                |  2 +-
 6 files changed, 43 insertions(+), 33 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 46bd6039657bd..1a21614a52236 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1367,7 +1367,7 @@ def GPU_ShuffleOp : GPU_Op<
 def GPU_RotateOp : GPU_Op<
     "rotate", [Pure, AllTypesMatch<["value", "rotateResult"]>]>,
     Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, I32:$offset, I32:$width)>,
-    Results<(outs AnyIntegerOrFloatOr1DVector:$rotateResult)> {
+    Results<(outs AnyIntegerOrFloatOr1DVector:$rotateResult, I1:$valid)> {
   let summary = "Rotate values within a subgroup.";
   let description = [{
     The "rotate" op moves values across lanes (a.k.a., invocations, work items)
@@ -1381,15 +1381,19 @@ def GPU_RotateOp : GPU_Op<
     Return the `rotateResult` of the invocation whose id within the group is
     calculated as follows:
 
-    Invocation ID = ((LocalId + Delta) & (width - 1)) + (LocalId & ~(width - 1))
+    ```mlir
+    Invocation ID = ((LaneId + offset) & (width - 1)) + (LaneId & ~(width - 1))
+    ```
 
-    Returns the `rotateResult` if the current lane id is smaller than `width`.
+    Returns the `rotateResult` and `true` if the current lane id is smaller than
+    `width`, and an unspecified value and `false` otherwise.
 
     example:
 
     ```mlir
     %cst1 = arith.constant 1 : i32
-    %1 = gpu.rotate %0, %cst1, %width : f32
+    %width = arith.constant 16 : i32
+    %1, %2 = gpu.rotate %0, %cst1, %width : f32
     ```
 
     For lane `k`, returns the value from lane `(k + cst1) % width`.
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 546705244f35c..d46d0563be057 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -489,7 +489,11 @@ LogicalResult GPURotateConversion::matchAndRewrite(
   Value result = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
       loc, scope, adaptor.getValue(), adaptor.getOffset(), adaptor.getWidth());
 
-  rewriter.replaceOp(rotateOp, result);
+  Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
+  Value validVal = rewriter.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::ult, laneId, adaptor.getWidth());
+
+  rewriter.replaceOp(rotateOp, {result, validVal});
   return success();
 }
 
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index a9a9473a1c333..a72207c7cebf9 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1345,27 +1345,26 @@ void RotateOp::build(OpBuilder &builder, OperationState &result, Value value,
 }
 
 LogicalResult RotateOp::verify() {
-  llvm::APInt offsetValue;
-  if (auto constOp = getOffset().getDefiningOp<arith::ConstantOp>()) {
-    if (auto intAttr = llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue())) {
-      offsetValue = intAttr.getValue();
-    } else {
-      return emitOpError() << "offset is not an integer value";
-    }
-  } else {
+  auto offsetConstOp = getOffset().getDefiningOp<arith::ConstantOp>();
+  if (!offsetConstOp)
     return emitOpError() << "offset is not a constant value";
-  }
 
-  llvm::APInt widthValue;
-  if (auto constOp = getWidth().getDefiningOp<arith::ConstantOp>()) {
-    if (auto intAttr = llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue())) {
-      widthValue = intAttr.getValue();
-    } else {
-      return emitOpError() << "width is not an integer value";
-    }
-  } else {
+  auto offsetIntAttr =
+      llvm::dyn_cast<mlir::IntegerAttr>(offsetConstOp.getValue());
+  if (!offsetIntAttr)
+    return emitOpError() << "offset is not an integer value";
+
+  auto widthConstOp = getWidth().getDefiningOp<arith::ConstantOp>();
+  if (!widthConstOp)
     return emitOpError() << "width is not a constant value";
-  }
+
+  auto widthIntAttr =
+      llvm::dyn_cast<mlir::IntegerAttr>(widthConstOp.getValue());
+  if (!widthIntAttr)
+    return emitOpError() << "width is not an integer value";
+
+  llvm::APInt offsetValue = offsetIntAttr.getValue();
+  llvm::APInt widthValue = widthIntAttr.getValue();
 
   if (!widthValue.isPowerOf2())
     return emitOpError() << "width must be a power of two";
diff --git a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
index e0dd14d87d42f..513377c2de697 100644
--- a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
@@ -18,7 +18,10 @@ gpu.module @kernels {
     // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
     // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
     // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
-    %result = gpu.rotate %val, %offset, %width : f32
+    // CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__
+    // CHECK: %[[INVOCATION_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]]
+    // CHECK: %[[VALID:.+]] = spirv.ULessThan %[[INVOCATION_ID]], %[[WIDTH]]
+    %result, %valid = gpu.rotate %val, %offset, %width : f32
     gpu.return
   }
 }
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 0ad5690f5cf70..b4fe3a51e7973 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -482,7 +482,7 @@ func.func @rotate_mismatching_type(%arg0 : f32) {
   %offset = arith.constant 4 : i32
   %width = arith.constant 16 : i32
   // expected-error at +1 {{op failed to verify that all of {value, rotateResult} have same type}}
-  %shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> i32
+  %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (i32, i1)
   return
 }
 
@@ -492,7 +492,7 @@ func.func @rotate_unsupported_type(%arg0 : index) {
   %offset = arith.constant 4 : i32
   %width = arith.constant 16 : i32
   // expected-error at +1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'index'}}
-  %shfl = gpu.rotate %arg0, %offset, %width : index
+  %rotate, %valid = gpu.rotate %arg0, %offset, %width : index
   return
 }
 
@@ -502,7 +502,7 @@ func.func @rotate_unsupported_type_vec(%arg0 : vector<[4]xf32>) {
   %offset = arith.constant 4 : i32
   %width = arith.constant 16 : i32
   // expected-error at +1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'vector<[4]xf32>'}}
-  %shfl = gpu.rotate %arg0, %offset, %width : vector<[4]xf32>
+  %rotate, %valid = gpu.rotate %arg0, %offset, %width : vector<[4]xf32>
   return
 }
 
@@ -512,7 +512,7 @@ func.func @rotate_unsupported_width(%arg0 : f32) {
   %offset = arith.constant 4 : i32
   %width = arith.constant 15 : i32
   // expected-error at +1 {{op width must be a power of two}}
-  %shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> f32
+  %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
   return
 }
 
@@ -522,7 +522,7 @@ func.func @rotate_unsupported_offset(%arg0 : f32) {
   %offset = arith.constant 16 : i32
   %width = arith.constant 16 : i32
   // expected-error at +1 {{op offset must be in the range [0, width)}}
-  %shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> f32
+  %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
   return
 }
 
@@ -532,7 +532,7 @@ func.func @rotate_unsupported_offset_minus(%arg0 : f32) {
   %offset = arith.constant -1 : i32
   %width = arith.constant 16 : i32
   // expected-error at +1 {{op offset must be in the range [0, width)}}
-  %shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> f32
+  %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
   return
 }
 
@@ -541,7 +541,7 @@ func.func @rotate_unsupported_offset_minus(%arg0 : f32) {
 func.func @rotate_offset_non_constant(%arg0 : f32, %offset : i32) {
   %width = arith.constant 16 : i32
   // expected-error at +1 {{op offset is not a constant value}}
-  %shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> f32
+  %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
   return
 }
 
@@ -550,7 +550,7 @@ func.func @rotate_offset_non_constant(%arg0 : f32, %offset : i32) {
 func.func @rotate_width_non_constant(%arg0 : f32, %width : i32) {
   %offset = arith.constant 0 : i32
   // expected-error at +1 {{op width is not a constant value}}
-  %shfl = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> f32
+  %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
   return
 }
 
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 4beb8ffa09ac6..2aef80f73feb3 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -142,7 +142,7 @@ module attributes {gpu.container_module} {
 
       // CHECK: gpu.rotate %{{.*}}, %{{.*}}, %{{.*}} : f32
       %rotate_width = arith.constant 16 : i32
-      %rotate = gpu.rotate %arg0, %offset, %rotate_width : f32
+      %rotate, %pred4 = gpu.rotate %arg0, %offset, %rotate_width : f32
 
       "gpu.barrier"() : () -> ()
 

>From 55c0c289adfde61ec0460651f6975c3fb137a19e Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 18 Jun 2025 21:08:30 +0100
Subject: [PATCH 3/6] Update description and remove redundant checks

---
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td  | 17 +++++++++--------
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp      |  4 ----
 mlir/test/Conversion/GPUToSPIRV/rotate.mlir |  2 +-
 3 files changed, 10 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 1a21614a52236..ac8f37bc6bd3c 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1304,8 +1304,8 @@ def GPU_ShuffleOp : GPU_Op<
     Results<(outs AnyIntegerOrFloatOr1DVector:$shuffleResult, I1:$valid)> {
   let summary = "Shuffles values within a subgroup.";
   let description = [{
-    The "shuffle" op moves values to a across lanes (a.k.a., invocations,
-    work items) within the same subgroup. The `width` argument specifies the
+    The "shuffle" op moves values across lanes in a subgroup (a.k.a., local
+    invocation) within the same subgroup. The `width` argument specifies the
     number of lanes that participate in the shuffle, and must be uniform
     across all lanes. Further, the first `width` lanes of the subgroup must
     be active.
@@ -1370,10 +1370,11 @@ def GPU_RotateOp : GPU_Op<
     Results<(outs AnyIntegerOrFloatOr1DVector:$rotateResult, I1:$valid)> {
   let summary = "Rotate values within a subgroup.";
   let description = [{
-    The "rotate" op moves values across lanes (a.k.a., invocations, work items)
-    within the same subgroup. The `width` argument specifies the number of lanes
-    that participate in the rotation, and must be uniform across all lanes.
-    Further, the first `width` lanes of the subgroup must be active.
+    The "rotate" op moves values across lanes in a subgroup (a.k.a., local
+    invocations) within the same subgroup. The `width` argument specifies the
+    number of lanes that participate in the rotation, and must be uniform across
+    all participating lanes. Further, the first `width` lanes of the subgroup
+    must be active.
 
     `width` must be a power of two, and `offset` must be in the range
     `[0, width)`.
@@ -1391,9 +1392,9 @@ def GPU_RotateOp : GPU_Op<
     example:
 
     ```mlir
-    %cst1 = arith.constant 1 : i32
+    %offset = arith.constant 1 : i32
     %width = arith.constant 16 : i32
-    %1, %2 = gpu.rotate %0, %cst1, %width : f32
+    %1, %2 = gpu.rotate %0, %offset, %width : f32
     ```
 
     For lane `k`, returns the value from lane `(k + cst1) % width`.
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index a72207c7cebf9..8eafe84a02720 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1351,8 +1351,6 @@ LogicalResult RotateOp::verify() {
 
   auto offsetIntAttr =
       llvm::dyn_cast<mlir::IntegerAttr>(offsetConstOp.getValue());
-  if (!offsetIntAttr)
-    return emitOpError() << "offset is not an integer value";
 
   auto widthConstOp = getWidth().getDefiningOp<arith::ConstantOp>();
   if (!widthConstOp)
@@ -1360,8 +1358,6 @@ LogicalResult RotateOp::verify() {
 
   auto widthIntAttr =
       llvm::dyn_cast<mlir::IntegerAttr>(widthConstOp.getValue());
-  if (!widthIntAttr)
-    return emitOpError() << "width is not an integer value";
 
   llvm::APInt offsetValue = offsetIntAttr.getValue();
   llvm::APInt widthValue = widthIntAttr.getValue();
diff --git a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
index 513377c2de697..5677e3315750f 100644
--- a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv %s -o - | FileCheck %s
 
 module attributes {
   gpu.container_module,

>From f91fcb86ebd03b44521e6a1314c09cb8e33f708e Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 19 Jun 2025 09:29:32 +0100
Subject: [PATCH 4/6] Return a poison value if the lane id is out of the rotation range

---
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    |  2 +-
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 20 +++++++++----
 mlir/test/Conversion/GPUToSPIRV/rotate.mlir   | 30 +++++++++++++++++++
 3 files changed, 46 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index ac8f37bc6bd3c..8893733fac8ca 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1387,7 +1387,7 @@ def GPU_RotateOp : GPU_Op<
     ```
 
     Returns the `rotateResult` and `true` if the current lane id is smaller than
-    `width`, and an unspecified value and `false` otherwise.
+    `width`, and poison value and `false` otherwise.
 
     example:
 
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index d46d0563be057..3c4f1450abbd8 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -486,12 +486,22 @@ LogicalResult GPURotateConversion::matchAndRewrite(
 
   Location loc = rotateOp.getLoc();
   auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
-  Value result = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
+  Value rotateResult = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
       loc, scope, adaptor.getValue(), adaptor.getOffset(), adaptor.getWidth());
-
-  Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
-  Value validVal = rewriter.create<arith::CmpIOp>(
-      loc, arith::CmpIPredicate::ult, laneId, adaptor.getWidth());
+  Value result;
+  Value validVal;
+  if (widthAttr.getValue().getZExtValue() == subgroupSize) {
+    result = rotateResult;
+    validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter);
+  } else {
+    Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
+    validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
+                                              laneId, adaptor.getWidth());
+    Value undefVal =
+        rewriter.create<spirv::UndefOp>(loc, rotateResult.getType());
+    result =
+        rewriter.create<spirv::SelectOp>(loc, validVal, rotateResult, undefVal);
+  }
 
   rewriter.replaceOp(rotateOp, {result, validVal});
   return success();
diff --git a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
index 5677e3315750f..9e4a4e5886a9c 100644
--- a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
@@ -18,9 +18,39 @@ gpu.module @kernels {
     // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
     // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
     // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
+    // CHECK: %{{.+}} = spirv.Constant true
+    %result, %valid = gpu.rotate %val, %offset, %width : f32
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
+    #spirv.resource_limits<subgroup_size = 16>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @rotate_width_less_than_subgroup_size()
+  gpu.func @rotate_width_less_than_subgroup_size() kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    %offset = arith.constant 4 : i32
+    %width = arith.constant 8 : i32
+    %val = arith.constant 42.0 : f32
+
+    // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
+    // CHECK: %[[WIDTH:.+]] = spirv.Constant 8 : i32
+    // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
+    // CHECK: %[[ROTATE_VAL:.+]] = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
     // CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__
     // CHECK: %[[INVOCATION_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]]
     // CHECK: %[[VALID:.+]] = spirv.ULessThan %[[INVOCATION_ID]], %[[WIDTH]]
+    // CHECK: %[[UNDEF:.+]] = spirv.Undef : f32
+    // CHECK: %[[RESULT:.+]] = spirv.Select %[[VALID]], %[[ROTATE_VAL]], %[[UNDEF]] : i1, f32
     %result, %valid = gpu.rotate %val, %offset, %width : f32
     gpu.return
   }

>From 01bbc43f5a7476aebe22c6f786ae7baddbb857e5 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 19 Jun 2025 16:48:09 +0100
Subject: [PATCH 5/6] Remove redundant undef constant and select

---
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 8 +-------
 mlir/test/Conversion/GPUToSPIRV/rotate.mlir   | 6 ++----
 2 files changed, 3 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 3c4f1450abbd8..8b2309cc0fce5 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -488,22 +488,16 @@ LogicalResult GPURotateConversion::matchAndRewrite(
   auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
   Value rotateResult = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
       loc, scope, adaptor.getValue(), adaptor.getOffset(), adaptor.getWidth());
-  Value result;
   Value validVal;
   if (widthAttr.getValue().getZExtValue() == subgroupSize) {
-    result = rotateResult;
     validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter);
   } else {
     Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
     validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                               laneId, adaptor.getWidth());
-    Value undefVal =
-        rewriter.create<spirv::UndefOp>(loc, rotateResult.getType());
-    result =
-        rewriter.create<spirv::SelectOp>(loc, validVal, rotateResult, undefVal);
   }
 
-  rewriter.replaceOp(rotateOp, {result, validVal});
+  rewriter.replaceOp(rotateOp, {rotateResult, validVal});
   return success();
 }
 
diff --git a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
index 9e4a4e5886a9c..b0c7005e1510d 100644
--- a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
@@ -45,12 +45,10 @@ gpu.module @kernels {
     // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
     // CHECK: %[[WIDTH:.+]] = spirv.Constant 8 : i32
     // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
-    // CHECK: %[[ROTATE_VAL:.+]] = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
+    // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
     // CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__
     // CHECK: %[[INVOCATION_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]]
-    // CHECK: %[[VALID:.+]] = spirv.ULessThan %[[INVOCATION_ID]], %[[WIDTH]]
-    // CHECK: %[[UNDEF:.+]] = spirv.Undef : f32
-    // CHECK: %[[RESULT:.+]] = spirv.Select %[[VALID]], %[[ROTATE_VAL]], %[[UNDEF]] : i1, f32
+    // CHECK: %{{.+}} = spirv.ULessThan %[[INVOCATION_ID]], %[[WIDTH]]
     %result, %valid = gpu.rotate %val, %offset, %width : f32
     gpu.return
   }

>From 58058c893a01d373c7d456aad2481e48f173f9aa Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Fri, 27 Jun 2025 11:45:28 +0100
Subject: [PATCH 6/6] improve error messages and add tests for gpu-to-spirv
 conversions

---
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp |  6 ++-
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        |  8 +++-
 mlir/test/Conversion/GPUToSPIRV/rotate.mlir   | 47 ++++++++++++++++++-
 mlir/test/Dialect/GPU/invalid.mlir            |  4 +-
 4 files changed, 58 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 8b2309cc0fce5..60c241c6ef2d0 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -475,14 +475,16 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
 LogicalResult GPURotateConversion::matchAndRewrite(
     gpu::RotateOp rotateOp, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
-  auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
+  const spirv::TargetEnv &targetEnv =
+      getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
   unsigned subgroupSize =
       targetEnv.getAttr().getResourceLimits().getSubgroupSize();
   IntegerAttr widthAttr;
   if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) ||
       widthAttr.getValue().getZExtValue() > subgroupSize)
     return rewriter.notifyMatchFailure(
-        rotateOp, "rotate width is larger than target subgroup size");
+        rotateOp,
+        "rotate width is not a constant or larger than target subgroup size");
 
   Location loc = rotateOp.getLoc();
   auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 8eafe84a02720..61d2bc3d93bfa 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1365,8 +1365,12 @@ LogicalResult RotateOp::verify() {
   if (!widthValue.isPowerOf2())
     return emitOpError() << "width must be a power of two";
 
-  if (offsetValue.sge(widthValue) || offsetValue.slt(0))
-    return emitOpError() << "offset must be in the range [0, width)";
+  if (offsetValue.sge(widthValue) || offsetValue.slt(0)) {
+    SmallString<8> widthStr;
+    widthValue.toStringUnsigned(widthStr);
+    return emitOpError() << "offset must be in the range [0, "
+                         << std::string(std::move(widthStr)) << ")";
+  }
 
   return success();
 }
diff --git a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
index b0c7005e1510d..b96dd37219b46 100644
--- a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv %s -o - | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s
 
 module attributes {
   gpu.container_module,
@@ -55,3 +55,48 @@ gpu.module @kernels {
 }
 
 }
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
+    #spirv.resource_limits<subgroup_size = 16>>
+} {
+
+gpu.module @kernels {
+  gpu.func @rotate_with_bigger_than_subgroup_size() kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    %offset = arith.constant 4 : i32
+    %width = arith.constant 32 : i32
+    %val = arith.constant 42.0 : f32
+
+    // expected-error @+1 {{failed to legalize operation 'gpu.rotate'}}
+    %result, %valid = gpu.rotate %val, %offset, %width : f32
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
+    #spirv.resource_limits<subgroup_size = 16>>
+} {
+
+gpu.module @kernels {
+  gpu.func @rotate_non_const_width(%width: i32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    %offset = arith.constant 4 : i32
+    %val = arith.constant 42.0 : f32
+
+    // expected-error @+1 {{'gpu.rotate' op width is not a constant value}}
+    %result, %valid = gpu.rotate %val, %offset, %width : f32
+    gpu.return
+  }
+}
+
+}
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index b4fe3a51e7973..162ff0662e91e 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -521,7 +521,7 @@ func.func @rotate_unsupported_width(%arg0 : f32) {
 func.func @rotate_unsupported_offset(%arg0 : f32) {
   %offset = arith.constant 16 : i32
   %width = arith.constant 16 : i32
-  // expected-error at +1 {{op offset must be in the range [0, width)}}
+  // expected-error at +1 {{op offset must be in the range [0, 16)}}
   %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
   return
 }
@@ -531,7 +531,7 @@ func.func @rotate_unsupported_offset(%arg0 : f32) {
 func.func @rotate_unsupported_offset_minus(%arg0 : f32) {
   %offset = arith.constant -1 : i32
   %width = arith.constant 16 : i32
-  // expected-error at +1 {{op offset must be in the range [0, width)}}
+  // expected-error at +1 {{op offset must be in the range [0, 16)}}
   %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
   return
 }



More information about the Mlir-commits mailing list