[Mlir-commits] [mlir] [MLIR][NVVM] Add Permute Op (PR #169793)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 27 04:10:18 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Dharuni R Acharya (DharuniRAcharya)

<details>
<summary>Changes</summary>

This patch adds the `nvvm.prmt` (permute) op.
Lit tests are added to verify the lowering to the intrinsics.
Negative tests are also added to check the error handling for invalid mode/operand combinations.

PTX spec reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prmt

---
Full diff: https://github.com/llvm/llvm-project/pull/169793.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+127) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+50) 
- (added) mlir/test/Target/LLVMIR/nvvm/permute_invalid.mlir (+43) 
- (added) mlir/test/Target/LLVMIR/nvvm/permute_valid.mlir (+64) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index d78145d690fc8..0b3ca385d0a78 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1567,6 +1567,133 @@ def NVVM_ElectSyncOp : NVVM_Op<"elect.sync">
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Permute Bytes (Prmt)
+//===----------------------------------------------------------------------===//
+
+// Attributes for the permute operation modes supported by PTX.
+def PermuteModeDefault : I32EnumAttrCase<"DEFAULT", 0, "default">;
+def PermuteModeF4E : I32EnumAttrCase<"F4E", 1, "f4e">;
+def PermuteModeB4E : I32EnumAttrCase<"B4E", 2, "b4e">;
+def PermuteModeRC8 : I32EnumAttrCase<"RC8", 3, "rc8">;
+def PermuteModeECL : I32EnumAttrCase<"ECL", 4, "ecl">;
+def PermuteModeECR : I32EnumAttrCase<"ECR", 5, "ecr">;
+def PermuteModeRC16 : I32EnumAttrCase<"RC16", 6, "rc16">;
+
+def PermuteMode : I32EnumAttr<"PermuteMode", "NVVM permute mode",
+                              [PermuteModeDefault, PermuteModeF4E,
+                               PermuteModeB4E, PermuteModeRC8, PermuteModeECL,
+                               PermuteModeECR, PermuteModeRC16]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+
+def PermuteModeAttr : EnumAttr<NVVM_Dialect, PermuteMode, "permute_mode"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_PermuteOp : NVVM_Op<"prmt", [Pure]>,
+                     Results<(outs I32:$res)>,
+                     Arguments<(ins I32:$lo, Optional<I32>:$hi, I32:$selector,
+                         PermuteModeAttr:$mode)> {
+  let summary = "Permute bytes from two 32-bit registers";
+  let description = [{
+    The `nvvm.prmt` operation constructs a permutation of the bytes of the
+    first one or two operands. In `default` mode each nibble of the final
+    operand selects a byte; the other modes use its 2 least significant bits.
+    
+    The bytes in the first one or two source operands are numbered.
+    The first source operand (%lo) is numbered {b3, b2, b1, b0}. For the
+    '``default``', '``f4e``' and '``b4e``' variants, the second source
+    operand (%hi) is numbered {b7, b6, b5, b4}.
+
+    Modes:
+    - `default`: Index mode         - each nibble in `selector` selects a byte from the 8-byte pool
+    - `f4e`    : Forward 4 extract  - extracts 4 contiguous bytes starting from position in `selector`
+    - `b4e`    : Backward 4 extract - extracts 4 contiguous bytes in reverse order
+    - `rc8`    : Replicate 8        - replicates the byte of `lo` chosen by `selector` into all four result bytes
+    - `ecl`    : Edge clamp left    - clamps out-of-range indices to the leftmost valid byte
+    - `ecr`    : Edge clamp right   - clamps out-of-range indices to the rightmost valid byte
+    - `rc16`   : Replicate 16       - replicates the 16-bit half of `lo` chosen by `selector` into both result halves
+
+    Depending on the 2 least significant bits of the %selector operand, the result
+    of the permutation is defined as follows:
+
+    +------------+----------------+--------------+
+    |    Mode    | %selector[1:0] |    Output    |
+    +------------+----------------+--------------+
+    | '``f4e``'  | 0              | {3, 2, 1, 0} |
+    |            +----------------+--------------+
+    |            | 1              | {4, 3, 2, 1} |
+    |            +----------------+--------------+
+    |            | 2              | {5, 4, 3, 2} |
+    |            +----------------+--------------+
+    |            | 3              | {6, 5, 4, 3} |
+    +------------+----------------+--------------+
+    | '``b4e``'  | 0              | {5, 6, 7, 0} |
+    |            +----------------+--------------+
+    |            | 1              | {6, 7, 0, 1} |
+    |            +----------------+--------------+
+    |            | 2              | {7, 0, 1, 2} |
+    |            +----------------+--------------+
+    |            | 3              | {0, 1, 2, 3} |
+    +------------+----------------+--------------+
+    | '``rc8``'  | 0              | {0, 0, 0, 0} |
+    |            +----------------+--------------+
+    |            | 1              | {1, 1, 1, 1} |
+    |            +----------------+--------------+
+    |            | 2              | {2, 2, 2, 2} |
+    |            +----------------+--------------+
+    |            | 3              | {3, 3, 3, 3} |
+    +------------+----------------+--------------+
+    | '``ecl``'  | 0              | {3, 2, 1, 0} |
+    |            +----------------+--------------+
+    |            | 1              | {3, 2, 1, 1} |
+    |            +----------------+--------------+
+    |            | 2              | {3, 2, 2, 2} |
+    |            +----------------+--------------+
+    |            | 3              | {3, 3, 3, 3} |
+    +------------+----------------+--------------+
+    | '``ecr``'  | 0              | {0, 0, 0, 0} |
+    |            +----------------+--------------+
+    |            | 1              | {1, 1, 1, 0} |
+    |            +----------------+--------------+
+    |            | 2              | {2, 2, 1, 0} |
+    |            +----------------+--------------+
+    |            | 3              | {3, 2, 1, 0} |
+    +------------+----------------+--------------+
+    | '``rc16``' | 0              | {1, 0, 1, 0} |
+    |            +----------------+--------------+
+    |            | 1              | {3, 2, 3, 2} |
+    |            +----------------+--------------+
+    |            | 2              | {1, 0, 1, 0} |
+    |            +----------------+--------------+
+    |            | 3              | {3, 2, 3, 2} |
+    +------------+----------------+--------------+
+    
+    [For more information, see PTX ISA]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prmt)
+  }];
+
+  let assemblyFormat = [{
+    $mode $lo `,` $selector (`,` $hi^)? attr-dict `:` type($res)
+  }];
+
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    static mlir::NVVM::IDArgPair
+    getIntrinsicIDAndArgs(NVVM::PermuteMode mode, llvm::Value *lo, 
+                          llvm::Value *hi, llvm::Value *selector);
+  }];
+
+  string llvmBuilder = [{
+    auto [id, args] = NVVM::PermuteOp::getIntrinsicIDAndArgs(
+        $mode, $lo, $hi, $selector);
+    $res = createIntrinsicCall(builder, id, args);
+  }];
+}
+
 def LoadCacheModifierCA : I32EnumAttrCase<"CA", 0, "ca">;
 def LoadCacheModifierCG : I32EnumAttrCase<"CG", 1, "cg">;
 def LoadCacheModifierCS : I32EnumAttrCase<"CS", 2, "cs">;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 428bc72c88a30..cfe5a98929467 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -448,6 +448,33 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() {
   return success();
 }
 
+LogicalResult PermuteOp::verify() {
+  using Mode = NVVM::PermuteMode;
+  bool hasHi = static_cast<bool>(getHi());
+
+  switch (getMode()) {
+  case Mode::DEFAULT:
+  case Mode::F4E:
+  case Mode::B4E:
+    if (!hasHi)
+      return emitError("mode '") << stringifyPermuteMode(getMode())
+                                 << "' requires 'hi' operand i.e. it requires "
+                                    "3 operands - lo, hi, selector.";
+    break;
+  case Mode::RC8:
+  case Mode::ECL:
+  case Mode::ECR:
+  case Mode::RC16:
+    if (hasHi)
+      return emitError("mode '") << stringifyPermuteMode(getMode())
+                                 << "' does not accept 'hi' operand i.e. it "
+                                    "requires 2 operands - lo, selector.";
+    break;
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Stochastic Rounding Conversion Ops
 //===----------------------------------------------------------------------===//
@@ -3379,6 +3406,29 @@ NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
   return {intrinsicID, args};
 }
 
+mlir::NVVM::IDArgPair PermuteOp::getIntrinsicIDAndArgs(NVVM::PermuteMode mode,
+                                                       llvm::Value *lo,
+                                                       llvm::Value *hi,
+                                                       llvm::Value *selector) {
+  static constexpr llvm::Intrinsic::ID IDs[] = {
+      llvm::Intrinsic::nvvm_prmt,     llvm::Intrinsic::nvvm_prmt_f4e,
+      llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8,
+      llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr,
+      llvm::Intrinsic::nvvm_prmt_rc16};
+
+  unsigned modeIndex = static_cast<unsigned>(mode);
+
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(lo);
+
+  if (modeIndex < 3)
+    args.push_back(hi);
+
+  args.push_back(selector);
+
+  return {IDs[modeIndex], args};
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM tcgen05.mma functions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/LLVMIR/nvvm/permute_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/permute_invalid.mlir
new file mode 100644
index 0000000000000..06beec2e4b78b
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/permute_invalid.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
+llvm.func @invalid_default_missing_hi(%lo: i32, %sel: i32) -> i32 {
+  // expected-error @below {{mode 'default' requires 'hi' operand i.e. it requires 3 operands - lo, hi, selector}}
+  %r = nvvm.prmt #nvvm.permute_mode<default> %lo, %sel : i32
+  llvm.return %r : i32
+}
+
+llvm.func @invalid_f4e_missing_hi(%lo: i32, %sel: i32) -> i32 {
+  // expected-error @below {{mode 'f4e' requires 'hi' operand i.e. it requires 3 operands - lo, hi, selector}}
+  %r = nvvm.prmt #nvvm.permute_mode<f4e> %lo, %sel : i32
+  llvm.return %r : i32
+}
+
+llvm.func @invalid_b4e_missing_hi(%lo: i32, %sel: i32) -> i32 {
+  // expected-error @below {{mode 'b4e' requires 'hi' operand i.e. it requires 3 operands - lo, hi, selector}}
+  %r = nvvm.prmt #nvvm.permute_mode<b4e> %lo, %sel : i32
+  llvm.return %r : i32
+}
+
+llvm.func @invalid_rc8_with_hi(%lo: i32, %sel: i32, %hi: i32) -> i32 {
+  // expected-error @below {{mode 'rc8' does not accept 'hi' operand i.e. it requires 2 operands - lo, selector}}
+  %r = nvvm.prmt #nvvm.permute_mode<rc8> %lo, %sel, %hi : i32
+  llvm.return %r : i32
+}
+
+llvm.func @invalid_ecl_with_hi(%lo: i32, %sel: i32, %hi: i32) -> i32 {
+  // expected-error @below {{mode 'ecl' does not accept 'hi' operand i.e. it requires 2 operands - lo, selector}}
+  %r = nvvm.prmt #nvvm.permute_mode<ecl> %lo, %sel, %hi : i32
+  llvm.return %r : i32
+}
+
+llvm.func @invalid_ecr_with_hi(%lo: i32, %sel: i32, %hi: i32) -> i32 {
+  // expected-error @below {{mode 'ecr' does not accept 'hi' operand i.e. it requires 2 operands - lo, selector}}
+  %r = nvvm.prmt #nvvm.permute_mode<ecr> %lo, %sel, %hi : i32
+  llvm.return %r : i32
+}
+
+llvm.func @invalid_rc16_with_hi(%lo: i32, %sel: i32, %hi: i32) -> i32 {
+  // expected-error @below {{mode 'rc16' does not accept 'hi' operand i.e. it requires 2 operands - lo, selector}}
+  %r = nvvm.prmt #nvvm.permute_mode<rc16> %lo, %sel, %hi : i32
+  llvm.return %r : i32
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/permute_valid.mlir b/mlir/test/Target/LLVMIR/nvvm/permute_valid.mlir
new file mode 100644
index 0000000000000..cfe150bfbac01
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/permute_valid.mlir
@@ -0,0 +1,64 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @test_prmt_default
+llvm.func @test_prmt_default(%lo: i32, %sel: i32, %hi: i32) -> i32 {
+  // CHECK: call i32 @llvm.nvvm.prmt(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %result = nvvm.prmt #nvvm.permute_mode<default> %lo, %sel, %hi : i32
+  llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_f4e
+llvm.func @test_prmt_f4e(%lo: i32, %pos: i32, %hi: i32) -> i32 {
+  // CHECK: call i32 @llvm.nvvm.prmt.f4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %result = nvvm.prmt #nvvm.permute_mode<f4e> %lo, %pos, %hi : i32
+  llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_b4e
+llvm.func @test_prmt_b4e(%lo: i32, %pos: i32, %hi: i32) -> i32 {
+  // CHECK: call i32 @llvm.nvvm.prmt.b4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %result = nvvm.prmt #nvvm.permute_mode<b4e> %lo, %pos, %hi : i32
+  llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_rc8
+llvm.func @test_prmt_rc8(%val: i32, %sel: i32) -> i32 {
+  // CHECK: call i32 @llvm.nvvm.prmt.rc8(i32 %{{.*}}, i32 %{{.*}})
+  %result = nvvm.prmt #nvvm.permute_mode<rc8> %val, %sel : i32
+  llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_ecl
+llvm.func @test_prmt_ecl(%val: i32, %sel: i32) -> i32 {
+  // CHECK: call i32 @llvm.nvvm.prmt.ecl(i32 %{{.*}}, i32 %{{.*}})
+  %result = nvvm.prmt #nvvm.permute_mode<ecl> %val, %sel : i32
+  llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_ecr
+llvm.func @test_prmt_ecr(%val: i32, %sel: i32) -> i32 {
+  // CHECK: call i32 @llvm.nvvm.prmt.ecr(i32 %{{.*}}, i32 %{{.*}})
+  %result = nvvm.prmt #nvvm.permute_mode<ecr> %val, %sel : i32
+  llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_rc16
+llvm.func @test_prmt_rc16(%val: i32, %sel: i32) -> i32 {
+  // CHECK: call i32 @llvm.nvvm.prmt.rc16(i32 %{{.*}}, i32 %{{.*}})
+  %result = nvvm.prmt #nvvm.permute_mode<rc16> %val, %sel : i32
+  llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_mixed
+llvm.func @test_prmt_mixed(%lo: i32, %sel: i32, %hi: i32) -> i32 {
+  // CHECK: call i32 @llvm.nvvm.prmt(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r1 = nvvm.prmt #nvvm.permute_mode<default> %lo, %sel, %hi : i32
+
+  // CHECK: call i32 @llvm.nvvm.prmt.rc8(i32 %{{.*}}, i32 %{{.*}})
+  %r2 = nvvm.prmt #nvvm.permute_mode<rc8> %r1, %sel : i32
+
+  // CHECK: call i32 @llvm.nvvm.prmt.f4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r3 = nvvm.prmt #nvvm.permute_mode<f4e> %r2, %lo, %sel : i32
+
+  llvm.return %r3 : i32
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/169793


More information about the Mlir-commits mailing list