[Mlir-commits] [llvm] [mlir] [mlir][AMDGPU] Implement AMDGPU DPP operation in MLIR. (PR #89233)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 16 05:40:44 PDT 2024


https://github.com/stefankoncarevic updated https://github.com/llvm/llvm-project/pull/89233

>From 01f72eae9a4da69aec300cd6399d8bccbc26720f Mon Sep 17 00:00:00 2001
From: stefankoncarevic <skoncare at amd.com>
Date: Thu, 18 Apr 2024 13:26:20 +0000
Subject: [PATCH 1/2] [mlir][AMDGPU] Implement AMDGPU DPP operation in MLIR.
 Defined AMDGPU DPP operation in mlir to represent semantics. Introduced a new
 enumeration attribute for different permutations and allowed for different
 types of arguments. Implemented constant attribute handling for
 ROCDL::DPPUpdateOp operation. The operation now correctly accepts constant
 attributes for dppCtrl, rowMask, bankMask, boundCtrl, and passes them to the
 corresponding LLVM intrinsic.

---
 llvm/include/llvm/IR/IntrinsicsAMDGPU.td      |   6 +
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  55 +++++++
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td  |  24 +++
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 148 +++++++++++++++++-
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  |  62 ++++++++
 mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir   | 148 ++++++++++++++++++
 6 files changed, 440 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir

diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
index 8c25467cc5e4b1..7c93c19a410e43 100644
--- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -2343,6 +2343,8 @@ def int_amdgcn_buffer_wbinvl1_vol :
 // VI Intrinsics
 //===----------------------------------------------------------------------===//
 
+// The llvm.amdgcn.mov.dpp.i32 intrinsic represents the mov.dpp operation in AMDGPU.
+// This operation is being deprecated and can be replaced with llvm.amdgcn.update.dpp.i32.
 // llvm.amdgcn.mov.dpp.i32 <src> <dpp_ctrl> <row_mask> <bank_mask> <bound_ctrl>
 def int_amdgcn_mov_dpp :
   Intrinsic<[llvm_anyint_ty],
@@ -2352,6 +2354,10 @@ def int_amdgcn_mov_dpp :
              ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>,
              ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>, IntrNoCallback, IntrNoFree]>;
 
+// The llvm.amdgcn.update.dpp.i32 intrinsic represents the update.dpp operation in AMDGPU.
+// It takes an old value, a source operand, a DPP control operand, a row mask, a bank mask, and a bound control.
+// This operation is equivalent to a sequence of v_mov_b32 operations.
+// It is preferred over llvm.amdgcn.mov.dpp.i32 for future use.
 // llvm.amdgcn.update.dpp.i32 <old> <src> <dpp_ctrl> <row_mask> <bank_mask> <bound_ctrl>
 // Should be equivalent to:
 // v_mov_b32 <dest> <old>
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index dddbd837c1cbc7..97e0580c898080 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -402,6 +402,61 @@ def AMDGPU_RawBufferAtomicUminOp :
   let hasVerifier = 1;
 }
 
+def AMDGPU_DPPPerm : I32EnumAttr<"DPPPerm",
+    "The possible permutations for a DPP operation",
+    [  
+      I32EnumAttrCase<"quad_perm",  0>,
+      I32EnumAttrCase<"row_shl",    1>,
+      I32EnumAttrCase<"row_shr",    2>,
+      I32EnumAttrCase<"row_ror",    3>,
+      I32EnumAttrCase<"wave_shl",   4>,
+      I32EnumAttrCase<"wave_shr",   5>,
+      I32EnumAttrCase<"wave_ror",   6>,
+      I32EnumAttrCase<"wave_rol",   7>,
+      I32EnumAttrCase<"row_mirror", 8>,
+      I32EnumAttrCase<"row_half_mirror", 9>,
+      I32EnumAttrCase<"row_bcast_15", 10>,
+      I32EnumAttrCase<"row_bcast_31", 11>
+    ]> {
+  let genSpecializedAttr = 0;  
+  let cppNamespace = "::mlir::amdgpu";
+}
+
+def AMDGPU_DPPPermAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_DPPPerm,
+  "dpp_perm">;
+
+def AMDGPU_DPPOp : AMDGPU_Op<"dpp", [SameTypeOperands, AllTypesMatch<["result", "old", "src"]>]>,
+  Arguments<(ins AnyType:$old,
+                  AnyType:$src,
+                  AMDGPU_DPPPermAttr:$kind,
+                  OptionalAttr<AnyAttrOf<[I32Attr, ArrayAttr, UnitAttr]>>:$permArgument,
+                  DefaultValuedAttr<I32Attr, "0xf">:$row_mask,
+                  DefaultValuedAttr<I32Attr, "0xf">:$bank_mask,
+                  DefaultValuedAttr<BoolAttr, "false">:$bound_ctrl)> {
+  let summary = "AMDGPU DPP operation";
+  let description = [{
+    This operation represents DPP functionality in a GPU program.
+     DPP provides the following operations:
+    - Full crossbar in a group of four (`quad_perm`)
+    - Wavefront shift left by one lane (`wave_shl`)
+    - Wavefront shift right by one lane (`wave_shr`)
+    - Wavefront rotate right by one lane (`wave_ror`)
+    - Wavefront rotate left by one lane (`wave_rol`)
+    - Row shift left by 1–15 lanes (`row_shl`)
+    - Row shift right by 1–15 lanes (`row_shr`)
+    - Row rotate right by 1–15 lanes (`row_ror`)
+    - Reverse within a row (`row_mirror`)
+    - Reverse within a half-row (`row_half_mirror`)
+    - Broadcast the 15th lane of each row to the next row (`row_bcast_15`)
+    - Broadcast lane 31 to rows 2 and 3 (`row_bcast_31`)
+  }];
+  let results = (outs AnyType:$result);
+  let assemblyFormat = [{
+    $old $src $kind (`(` $permArgument^ `)`)? attr-dict `:` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
 def AMDGPU_LDSBarrierOp : AMDGPU_Op<"lds_barrier"> {
   let summary = "Barrier that includes a wait for LDS memory operations.";
   let description = [{
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index b81b7e577b213a..217faa237a3eb8 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -555,6 +555,30 @@ def ROCDL_RawBufferAtomicUMinOp :
   let hasCustomAssemblyFormat = 1;
 }
 
+// DPP Update intrinsic
+def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
+    [AllTypesMatch<["res", "src", "old"]>], 1>,
+  Arguments<(ins LLVM_Type:$old, LLVM_Type:$src, I32Attr:$dppCtrl, I32Attr:$rowMask,
+      I32Attr:$bankMask, I1Attr:$boundCtrl)> {
+  let results = (outs LLVM_Type:$res);
+  let assemblyFormat = [{
+    attr-dict $old `,` $src `with` $dppCtrl `,` $rowMask `,` $bankMask `,` $boundCtrl `:` type($src)
+  }];
+  string llvmBuilder = [{
+      auto vdataType = moduleTranslation.convertType(op.getSrc().getType());
+      llvm::Value *args[] = {
+        moduleTranslation.lookupValue(op.getOld()),
+        moduleTranslation.lookupValue(op.getSrc()),
+          builder.getInt32(op.getDppCtrl()),
+          builder.getInt32(op.getRowMask()),
+          builder.getInt32(op.getBankMask()),
+          builder.getInt1(op.getBoundCtrl())
+      }; 
+      $res = createIntrinsicCall(builder,
+        llvm::Intrinsic::amdgcn_update_dpp, args, {vdataType});
+  }];
+}
+
 //===---------------------------------------------------------------------===//
 // 8-bit float intrinsics
 //===---------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b808738804030f..499ee315b41f97 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -844,6 +844,147 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
   return success();
 }
 
+// Conversion pattern that lowers the amdgpu.dpp operation to the
+// corresponding rocdl.update.dpp operation.
+struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
+  AMDGPUDPPLowering(LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    // Convert the source operand to the corresponding LLVM type
+    Location loc = DppOp.getLoc();
+    Value src = adaptor.getSrc();
+    Value old = adaptor.getOld();
+    Type srcType = src.getType();
+    Type oldType = old.getType();
+    auto llvmI32Type = typeConverter->convertType(rewriter.getI32Type());
+    auto llvmSrcIntType = typeConverter->convertType(
+        rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
+
+    // If the source type is less or equal to i32 or f32, use bitcast to convert
+    // it to i32.
+    auto convertOperand = [&](Value operand, Type operandType) {
+      if (llvm::isa<FloatType>(operandType)) {
+        operand =
+            rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
+      }
+
+      if (operandType.getIntOrFloatBitWidth() < 32) {
+        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
+            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
+        Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
+        operand = rewriter.create<LLVM::InsertElementOp>(
+            loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
+        operand = rewriter.create<LLVM::BitcastOp>(loc, llvmI32Type, operand);
+      }
+      return operand;
+    };
+
+    src = convertOperand(src, srcType);
+    old = convertOperand(old, oldType);
+
+    // This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
+    enum DppCtrl : unsigned {
+      ROW_SHL0 = 0x100,
+      ROW_SHR0 = 0x110,
+      ROW_ROR0 = 0x120,
+      WAVE_SHL1 = 0x130,
+      WAVE_ROL1 = 0x134,
+      WAVE_SHR1 = 0x138,
+      WAVE_ROR1 = 0x13C,
+      ROW_MIRROR = 0x140,
+      ROW_HALF_MIRROR = 0x141,
+      BCAST15 = 0x142,
+      BCAST31 = 0x143,
+    };
+
+    auto kind = DppOp.getKind();
+    auto permArgument = DppOp.getPermArgument();
+    uint32_t DppCtrl = 0;
+
+    switch (kind) {
+
+    case DPPPerm::quad_perm:
+      if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
+        int32_t i = 0;
+        for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
+          uint32_t num = elem.getInt();
+          DppCtrl |= num << (i * 2);
+          i++;
+        }
+      }
+      break;
+    case DPPPerm::row_shl:
+      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
+        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
+      }
+      break;
+    case DPPPerm::row_shr:
+      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
+        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
+      }
+      break;
+    case DPPPerm::row_ror:
+      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
+        DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
+      }
+      break;
+    case DPPPerm::wave_shl:
+      DppCtrl = DppCtrl::WAVE_SHL1;
+      break;
+    case DPPPerm::wave_shr:
+      DppCtrl = DppCtrl::WAVE_SHR1;
+      break;
+    case DPPPerm::wave_rol:
+      DppCtrl = DppCtrl::WAVE_ROL1;
+      break;
+    case DPPPerm::wave_ror:
+      DppCtrl = DppCtrl::WAVE_ROR1;
+      break;
+    case DPPPerm::row_mirror:
+      DppCtrl = DppCtrl::ROW_MIRROR;
+      break;
+    case DPPPerm::row_half_mirror:
+      DppCtrl = DppCtrl::ROW_HALF_MIRROR;
+      break;
+    case DPPPerm::row_bcast_15:
+      DppCtrl = DppCtrl::BCAST15;
+      break;
+    case DPPPerm::row_bcast_31:
+      DppCtrl = DppCtrl::BCAST31;
+      break;
+    }
+
+    // Check for row_mask, bank_mask, bound_ctrl if they exist and create
+    // constants
+    auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
+    auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
+    bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
+
+    // create a ROCDL_DPPMovOp instruction with the appropriate attributes
+    auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
+        loc, llvmI32Type, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
+
+    Value result = dppMovOp.getRes();
+    if (srcType.getIntOrFloatBitWidth() < 32) {
+      result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
+    }
+
+    if (!llvm::isa<IntegerType>(srcType)) {
+      result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
+    }
+
+    // We are replacing the AMDGPU_DPPOp instruction with the new
+    // ROCDL_DPPUpdateOp instruction
+    rewriter.replaceOp(DppOp, ValueRange(result));
+    return success();
+  }
+};
+
 struct ConvertAMDGPUToROCDLPass
     : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
   ConvertAMDGPUToROCDLPass() = default;
@@ -895,9 +1036,10 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                ROCDL::RawPtrBufferAtomicUminOp>,
            RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                                ROCDL::RawPtrBufferAtomicCmpSwap>,
-           LDSBarrierOpLowering, SchedBarrierOpLowering, MFMAOpLowering,
-           WMMAOpLowering, ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
-           PackedStochRoundFp8OpLowering>(converter, chipset);
+           AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
+           MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
+           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
+                                                                      chipset);
 }
 
 std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index e3beceaa3bbb5b..08e912d4e43e04 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -326,6 +326,68 @@ LogicalResult MFMAOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// DPPOp
+//===----------------------------------------------------------------------===//
+LogicalResult DPPOp::verify() {
+  Type srcType = getSrc().getType();
+  if (srcType.getIntOrFloatBitWidth() > 32) {
+    return emitOpError("integer and floating point types larger than 32 bits "
+                       "are not supported");
+  }
+
+  DPPPerm kind = getKind();
+  Attribute permArgument = getPermArgument().value_or(Attribute{});
+  // Each permutation kind dictates which form of permArgument is acceptable.
+  switch (kind) {
+
+  case DPPPerm::quad_perm: {
+    auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
+    if (!quadPermAttr || quadPermAttr.size() != 4) {
+      return emitOpError("quad_perm attribute must have exactly 4 elements");
+    }
+    // Reject non-integer elements with a diagnostic instead of a failed cast.
+    for (Attribute elem : quadPermAttr) {
+      auto laneAttr = dyn_cast<IntegerAttr>(elem);
+      if (!laneAttr || laneAttr.getInt() < 0 || laneAttr.getInt() > 3) {
+        return emitOpError(
+            "Each element of quad_perm must be in the range [0, 3]");
+      }
+    }
+  } break;
+
+  case DPPPerm::row_shl:
+  case DPPPerm::row_shr:
+  case DPPPerm::row_ror: {
+    if (!permArgument) {
+      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
+                         "' value not specified");
+    }
+    // The lowering casts this to IntegerAttr, so any other kind must fail here.
+    auto intAttr = dyn_cast<IntegerAttr>(permArgument);
+    if (!intAttr || intAttr.getInt() < 1 || intAttr.getInt() > 15) {
+      return emitOpError("Attribute value must be between 1 and 15");
+    }
+  } break;
+
+  case DPPPerm::wave_shl:
+  case DPPPerm::wave_shr:
+  case DPPPerm::wave_rol:
+  case DPPPerm::wave_ror:
+  case DPPPerm::row_mirror:
+  case DPPPerm::row_half_mirror:
+  case DPPPerm::row_bcast_15:
+  case DPPPerm::row_bcast_31: {
+    if (permArgument && !isa<UnitAttr>(permArgument)) {
+      return emitOpError("Expected unit attribute for permArgument, but found "
+                         "non-trivial argument");
+    }
+    break;
+  }
+  }
+  return success();
+}
+
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir b/mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir
new file mode 100644
index 00000000000000..c982342c879f5d
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir
@@ -0,0 +1,148 @@
+// RUN: mlir-opt -convert-amdgpu-to-rocdl=chipset=gfx908 %s | FileCheck %s
+// RUN: mlir-opt -convert-amdgpu-to-rocdl=chipset=gfx90a %s | FileCheck %s
+// RUN: mlir-opt -convert-amdgpu-to-rocdl=chipset=gfx942 %s | FileCheck %s
+
+func.func @test_dpp(%arg0: i32, %arg1: i32) -> i32 {
+  // CHECK-LABEL: func @test_dpp
+  // CHECK: rocdl.update.dpp %arg0, %arg1 with 257, 10, 15, false : i32
+  // CHECK: return %0 : i32
+  %0 = amdgpu.dpp %arg0 %arg1 row_shl ( 0x1 : i32 ) { row_mask = 0xa : i32, bound_ctrl = false } : i32
+    return %0 : i32
+}
+
+func.func @quad_dpp(%arg0: i32, %arg1: i32) -> i32 {
+  // CHECK-LABEL: func @quad_dpp
+  // CHECK: rocdl.update.dpp %arg0, %arg1 with 145, 1, 1, true : i32
+  // CHECK: return %0 : i32
+  %0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [1,0,1,2] ) { row_mask = 0x1 : i32, bank_mask = 0x1 : i32, bound_ctrl = true } : i32
+    return %0 : i32
+}
+
+func.func @quad_perm_dpp(%arg0: i32, %arg1: i32) -> i32 {
+  // CHECK-LABEL: func @quad_perm_dpp
+  // CHECK: rocdl.update.dpp %arg0, %arg1 with 88, 15, 15, false : i32
+  // CHECK: return %0 : i32
+  %0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) : i32
+    return %0 : i32
+}
+
+func.func @wave_shr_dpp(%arg0: i32, %arg1: i32) -> i32 {
+  // CHECK-LABEL: func @wave_shr_dpp
+  // CHECK: rocdl.update.dpp %arg0, %arg1 with 312, 10, 1, true : i32
+  // CHECK: return %0 : i32
+  %0 = amdgpu.dpp %arg0 %arg1 wave_shr { row_mask = 0xa : i32, bank_mask = 0x1 : i32, bound_ctrl = true } : i32
+    return %0 : i32
+}
+
+func.func @row_bcast_dpp(%arg0: i32, %arg1: i32) -> i32 {
+  // CHECK-LABEL: func @row_bcast_dpp
+  // CHECK: rocdl.update.dpp %arg0, %arg1 with 323, 4, 1, false : i32
+  // CHECK: return %0 : i32
+  %0 = amdgpu.dpp %arg0 %arg1 row_bcast_31 { row_mask = 0x4 : i32, bank_mask = 0x1 : i32} : i32
+    return %0 : i32
+}
+
+func.func @row_bcast_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
+  // CHECK-LABEL: func @row_bcast_dpp_f32
+  // CHECK: llvm.bitcast %arg1 : f32 to i32
+  // CHECK: llvm.bitcast %arg0 : f32 to i32
+  // CHECK: rocdl.update.dpp %1, %0 with 322, 15, 15, true : i32
+  // CHECK: llvm.bitcast %2 : i32 to f32
+  // CHECK: return %3 : f32
+  %0 = amdgpu.dpp %arg0 %arg1 row_bcast_15 { bound_ctrl = true } : f32
+    return %0 : f32
+}
+
+func.func @row_half_mirror_update_dpp(%arg0: i32, %arg1: i32) -> i32 {
+  // CHECK-LABEL: func @row_half_mirror_update_dpp
+  // CHECK: rocdl.update.dpp %arg0, %arg1 with 321, 15, 1, false : i32
+  // CHECK: return %0 : i32
+%0 = amdgpu.dpp %arg0 %arg1 row_half_mirror { bank_mask = 0x1 : i32 } : i32
+    return %0 : i32
+}
+
+func.func @wave_rol_update_dpp(%arg0: i32, %arg1: i32) -> i32 {
+  // CHECK-LABEL: func @wave_rol_update_dpp
+  // CHECK: rocdl.update.dpp %arg0, %arg1 with 308, 10, 1, false : i32
+  // CHECK: return %0 : i32
+  %0 = amdgpu.dpp %arg0 %arg1 wave_rol { row_mask = 0xa : i32, bank_mask = 0x1 : i32 } : i32
+    return %0 : i32
+}
+
+func.func @test_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
+  // CHECK-LABEL: func @test_dpp_f32
+  // CHECK: llvm.bitcast %arg1 : f32 to i32
+  // CHECK: llvm.bitcast %arg0 : f32 to i32
+  // CHECK: rocdl.update.dpp %1, %0 with 320, 1, 4, true : i32
+  // CHECK: llvm.bitcast %2 : i32 to f32
+  // CHECK: return %3 : f32
+  %0 = amdgpu.dpp %arg0 %arg1 row_mirror { row_mask = 0x1 : i32, bank_mask = 0x4 : i32, bound_ctrl = true } : f32
+    return %0 : f32
+}
+
+func.func @test_dpp_f16(%arg0: f16, %arg1: f16) -> f16 {
+  // CHECK-LABEL:  func @test_dpp_f16
+  // CHECK: llvm.bitcast %arg1 : f16 to i16
+  // CHECK: llvm.mlir.undef : vector<2xi16>
+  // CHECK: llvm.mlir.constant(0 : i32) : i32
+  // CHECK: llvm.insertelement %0, %1[%2 : i32] : vector<2xi16>
+  // CHECK: llvm.bitcast %3 : vector<2xi16> to i32
+  // CHECK: llvm.bitcast %arg0 : f16 to i16
+  // CHECK: llvm.mlir.undef : vector<2xi16>
+  // CHECK: llvm.mlir.constant(0 : i32) : i32
+  // CHECK: llvm.insertelement %5, %6[%7 : i32] : vector<2xi16>
+  // CHECK: llvm.bitcast %8 : vector<2xi16> to i32
+  // CHECK: rocdl.update.dpp %9, %4 with 273, 15, 3, false : i32
+  // CHECK: llvm.trunc %10 : i32 to i16
+  // CHECK: llvm.bitcast %11 : i16 to f16
+  // CHECK: return %12 : f16
+  %0 = amdgpu.dpp %arg0 %arg1 row_shr ( 0x1 : i32 ){ bank_mask = 0x3 : i32 } : f16
+    return %0 : f16
+}
+
+func.func @row_shl_dpp_i16(%arg0: i16, %arg1: i16) -> i16 {
+  // CHECK-LABEL: func @row_shl_dpp_i16
+  // CHECK: llvm.mlir.undef : vector<2xi16>
+  // CHECK: llvm.mlir.constant(0 : i32) : i32
+  // CHECK: llvm.insertelement %arg1, %0[%1 : i32] : vector<2xi16>
+  // CHECK: llvm.bitcast %2 : vector<2xi16> to i32
+  // CHECK: llvm.mlir.undef : vector<2xi16>
+  // CHECK: llvm.mlir.constant(0 : i32) : i32
+  // CHECK: llvm.insertelement %arg0, %4[%5 : i32] : vector<2xi16>
+  // CHECK: llvm.bitcast %6 : vector<2xi16> to i32
+  // CHECK: rocdl.update.dpp %7, %3 with 298, 10, 1, false : i32
+  // CHECK: llvm.trunc %8 : i32 to i16
+  // CHECK: return %9 : i16
+  %0 = amdgpu.dpp %arg0 %arg1 row_ror ( 0xa : i32 ) { row_mask = 0xa : i32, bank_mask = 0x1 : i32 } : i16
+    return %0 : i16
+}
+
+func.func @quad_perm_update_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
+  // CHECK-LABEL: func @quad_perm_update_dpp_f32
+  // CHECK: llvm.bitcast %arg1 : f32 to i32
+  // CHECK: llvm.bitcast %arg0 : f32 to i32
+  // CHECK: rocdl.update.dpp %1, %0 with  88, 15, 1, false : i32
+  // CHECK: llvm.bitcast %2 : i32 to f32
+  // CHECK: return %3 : f32
+  %0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) { bank_mask = 0x1 : i32 } : f32
+    return %0 : f32
+}
+
+func.func @row_bcast_update_dpp_f16(%arg0: f16, %arg1: f16) -> f16 {
+  // CHECK-LABEL: func @row_bcast_update_dpp_f16
+  // CHECK: llvm.bitcast %arg1 : f16 to i16
+  // CHECK: llvm.mlir.undef : vector<2xi16>
+  // CHECK: llvm.mlir.constant(0 : i32) : i32
+  // CHECK: llvm.insertelement %0, %1[%2 : i32] : vector<2xi16>
+  // CHECK: llvm.bitcast %arg0 : f16 to i16
+  // CHECK: llvm.mlir.undef : vector<2xi16>
+  // CHECK: llvm.mlir.constant(0 : i32) : i32
+  // CHECK: llvm.insertelement %5, %6[%7 : i32] : vector<2xi16>
+  // CHECK: llvm.bitcast %8 : vector<2xi16> to i32
+  // CHECK: rocdl.update.dpp %9, %4 with 322, 15, 15, true : i32
+  // CHECK: llvm.trunc %10 : i32 to i16
+  // CHECK: llvm.bitcast %11 : i16 to f16
+  // CHECK: return %12 : f16
+  %0 = amdgpu.dpp %arg0 %arg1 row_bcast_15 { bound_ctrl = true } : f16
+    return %0 : f16
+}

>From 8141ec75be7ddacb815aec63bcf205624d559dbc Mon Sep 17 00:00:00 2001
From: stefankoncarevic <skoncare at amd.com>
Date: Thu, 16 May 2024 13:30:40 +0000
Subject: [PATCH 2/2] [mlir][AMDGPU] "Added support for 64-bit operands in
 ROCDL::DPPUpdateOp operation."

---
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td  |  2 +-
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 38 +++++----
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  |  4 +-
 mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir   | 77 ++++++++-----------
 4 files changed, 60 insertions(+), 61 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 217faa237a3eb8..a1e6fc3e299009 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -573,7 +573,7 @@ def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
           builder.getInt32(op.getRowMask()),
           builder.getInt32(op.getBankMask()),
           builder.getInt1(op.getBoundCtrl())
-      }; 
+      };
       $res = createIntrinsicCall(builder,
         llvm::Intrinsic::amdgcn_update_dpp, args, {vdataType});
   }];
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 499ee315b41f97..809e9448e80abf 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -861,25 +861,34 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
     Value old = adaptor.getOld();
     Type srcType = src.getType();
     Type oldType = old.getType();
-    auto llvmI32Type = typeConverter->convertType(rewriter.getI32Type());
+    Type llvmType = nullptr;
+    if (srcType.getIntOrFloatBitWidth() < 32) {
+      llvmType = rewriter.getI32Type();
+    } else if (isa<FloatType>(srcType)) {
+      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
+                     ? rewriter.getF32Type()
+                     : rewriter.getF64Type();
+    } else if (isa<IntegerType>(srcType)) {
+      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
+                     ? rewriter.getI32Type()
+                     : rewriter.getI64Type();
+    }
     auto llvmSrcIntType = typeConverter->convertType(
         rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
 
-    // If the source type is less or equal to i32 or f32, use bitcast to convert
-    // it to i32.
+    // If the source type is narrower than 32 bits, bitcast it to i32.
     auto convertOperand = [&](Value operand, Type operandType) {
-      if (llvm::isa<FloatType>(operandType)) {
-        operand =
-            rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
-      }
-
-      if (operandType.getIntOrFloatBitWidth() < 32) {
+      if (operandType.getIntOrFloatBitWidth() <= 16) {
+        if (llvm::isa<FloatType>(operandType)) {
+          operand =
+              rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
+        }
         auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
             32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
         Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
         operand = rewriter.create<LLVM::InsertElementOp>(
             loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
-        operand = rewriter.create<LLVM::BitcastOp>(loc, llvmI32Type, operand);
+        operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
       }
       return operand;
     };
@@ -967,15 +976,14 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
 
     // create a ROCDL_DPPMovOp instruction with the appropriate attributes
     auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
-        loc, llvmI32Type, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
+        loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
 
     Value result = dppMovOp.getRes();
     if (srcType.getIntOrFloatBitWidth() < 32) {
       result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
-    }
-
-    if (!llvm::isa<IntegerType>(srcType)) {
-      result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
+      if (!llvm::isa<IntegerType>(srcType)) {
+        result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
+      }
     }
 
     // We are replacing the AMDGPU_DPPOp instruction with the new
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 08e912d4e43e04..c1a785fb25478d 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -331,8 +331,8 @@ LogicalResult MFMAOp::verify() {
 //===----------------------------------------------------------------------===//
 LogicalResult DPPOp::verify() {
   Type srcType = getSrc().getType();
-  if (srcType.getIntOrFloatBitWidth() > 32) {
-    return emitOpError("integer and floating point types larger than 32 bits "
+  if (srcType.getIntOrFloatBitWidth() > 64) {
+    return emitOpError("integer and floating point types larger than 64 bits "
                        "are not supported");
   }
 
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir b/mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir
index c982342c879f5d..14691e73e62d78 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir
@@ -18,14 +18,6 @@ func.func @quad_dpp(%arg0: i32, %arg1: i32) -> i32 {
     return %0 : i32
 }
 
-func.func @quad_perm_dpp(%arg0: i32, %arg1: i32) -> i32 {
-  // CHECK-LABEL: func @quad_perm_dpp
-  // CHECK: rocdl.update.dpp %arg0, %arg1 with 88, 15, 15, false : i32
-  // CHECK: return %0 : i32
-  %0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) : i32
-    return %0 : i32
-}
-
 func.func @wave_shr_dpp(%arg0: i32, %arg1: i32) -> i32 {
   // CHECK-LABEL: func @wave_shr_dpp
   // CHECK: rocdl.update.dpp %arg0, %arg1 with 312, 10, 1, true : i32
@@ -34,25 +26,6 @@ func.func @wave_shr_dpp(%arg0: i32, %arg1: i32) -> i32 {
     return %0 : i32
 }
 
-func.func @row_bcast_dpp(%arg0: i32, %arg1: i32) -> i32 {
-  // CHECK-LABEL: func @row_bcast_dpp
-  // CHECK: rocdl.update.dpp %arg0, %arg1 with 323, 4, 1, false : i32
-  // CHECK: return %0 : i32
-  %0 = amdgpu.dpp %arg0 %arg1 row_bcast_31 { row_mask = 0x4 : i32, bank_mask = 0x1 : i32} : i32
-    return %0 : i32
-}
-
-func.func @row_bcast_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
-  // CHECK-LABEL: func @row_bcast_dpp_f32
-  // CHECK: llvm.bitcast %arg1 : f32 to i32
-  // CHECK: llvm.bitcast %arg0 : f32 to i32
-  // CHECK: rocdl.update.dpp %1, %0 with 322, 15, 15, true : i32
-  // CHECK: llvm.bitcast %2 : i32 to f32
-  // CHECK: return %3 : f32
-  %0 = amdgpu.dpp %arg0 %arg1 row_bcast_15 { bound_ctrl = true } : f32
-    return %0 : f32
-}
-
 func.func @row_half_mirror_update_dpp(%arg0: i32, %arg1: i32) -> i32 {
   // CHECK-LABEL: func @row_half_mirror_update_dpp
   // CHECK: rocdl.update.dpp %arg0, %arg1 with 321, 15, 1, false : i32
@@ -69,17 +42,46 @@ func.func @wave_rol_update_dpp(%arg0: i32, %arg1: i32) -> i32 {
     return %0 : i32
 }
 
+func.func @row_bcast_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
+  // CHECK-LABEL: func @row_bcast_dpp_f32
+  // CHECK: rocdl.update.dpp %arg0, %arg1 with 322, 15, 15, true : f32
+  // CHECK: return %0 : f32
+  %0 = amdgpu.dpp %arg0 %arg1 row_bcast_15 { bound_ctrl = true } : f32
+    return %0 : f32
+}
+
 func.func @test_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
   // CHECK-LABEL: func @test_dpp_f32
-  // CHECK: llvm.bitcast %arg1 : f32 to i32
-  // CHECK: llvm.bitcast %arg0 : f32 to i32
-  // CHECK: rocdl.update.dpp %1, %0 with 320, 1, 4, true : i32
-  // CHECK: llvm.bitcast %2 : i32 to f32
-  // CHECK: return %3 : f32
+  // CHECK: rocdl.update.dpp %arg0, %arg1 with 320, 1, 4, true : f32
+  // CHECK: return %0 : f32
   %0 = amdgpu.dpp %arg0 %arg1 row_mirror { row_mask = 0x1 : i32, bank_mask = 0x4 : i32, bound_ctrl = true } : f32
     return %0 : f32
 }
 
+func.func @quad_perm_update_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
+  // CHECK-LABEL: func @quad_perm_update_dpp_f32
+  // CHECK: rocdl.update.dpp %arg0, %arg1 with  88, 15, 1, false : f32
+  // CHECK: return %0 : f32
+  %0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) { bank_mask = 0x1 : i32 } : f32
+    return %0 : f32
+}
+
+func.func @quad_perm_dpp(%arg0: i64, %arg1: i64) -> i64 {
+  // CHECK-LABEL: func @quad_perm_dpp
+  // CHECK: rocdl.update.dpp %arg0, %arg1 with 88, 15, 15, false : i64
+  // CHECK: return %0 : i64
+  %0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) : i64
+    return %0 : i64
+}
+
+func.func @row_bcast_dpp(%arg0: f64, %arg1: f64) -> f64 {
+  // CHECK-LABEL: func @row_bcast_dpp
+  // CHECK: rocdl.update.dpp %arg0, %arg1 with 323, 4, 1, false : f64
+  // CHECK: return %0 : f64
+  %0 = amdgpu.dpp %arg0 %arg1 row_bcast_31 { row_mask = 0x4 : i32, bank_mask = 0x1 : i32} : f64
+    return %0 : f64
+}
+
 func.func @test_dpp_f16(%arg0: f16, %arg1: f16) -> f16 {
   // CHECK-LABEL:  func @test_dpp_f16
   // CHECK: llvm.bitcast %arg1 : f16 to i16
@@ -117,17 +119,6 @@ func.func @row_shl_dpp_i16(%arg0: i16, %arg1: i16) -> i16 {
     return %0 : i16
 }
 
-func.func @quad_perm_update_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
-  // CHECK-LABEL: func @quad_perm_update_dpp_f32
-  // CHECK: llvm.bitcast %arg1 : f32 to i32
-  // CHECK: llvm.bitcast %arg0 : f32 to i32
-  // CHECK: rocdl.update.dpp %1, %0 with  88, 15, 1, false : i32
-  // CHECK: llvm.bitcast %2 : i32 to f32
-  // CHECK: return %3 : f32
-  %0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) { bank_mask = 0x1 : i32 } : f32
-    return %0 : f32
-}
-
 func.func @row_bcast_update_dpp_f16(%arg0: f16, %arg1: f16) -> f16 {
   // CHECK-LABEL: func @row_bcast_update_dpp_f16
   // CHECK: llvm.bitcast %arg1 : f16 to i16



More information about the Mlir-commits mailing list