[Mlir-commits] [mlir] [mlir] AMDGPUToROCDL: lower `amdgpu.swizzle_bitmode` (PR #136223)

Ivan Butygin llvmlistbot at llvm.org
Thu Apr 17 16:23:13 PDT 2025


https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/136223

Repack `amdgpu.swizzle_bitmode` arguments and lower it to `rocdl.ds_swizzle`.

The repacking logic is as follows:
* `sizeof(arg) < sizeof(i32)`: bitcast to an integer, zext to i32, then trunc and bitcast back.
* `sizeof(arg) == sizeof(i32)`: just bitcast to i32 and back (unless the argument is already i32).
* `sizeof(arg) > sizeof(i32)`: bitcast to `vector<Nxi32>`, extract the individual elements, apply a `rocdl.ds_swizzle` to each, then recompose the vector and bitcast back.

Added repacking logic to LLVM utils so it can be used elsewhere. I'm planning to use it for `gpu.shuffle` later.

>From 0916901c87442ddf2bf529870ef486f16a16b4b4 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 13 Apr 2025 06:26:34 +0200
Subject: [PATCH] [mlir] AMDGPUToROCDL: lower `amdgpu.swizzle_bitmode`

Repack `amdgpu.swizzle_bitmode` arguments and lower it to `rocdl.ds_swizzle`.

The repacking logic is as follows:
* `sizeof(arg) < sizeof(i32)`: bitcast to an integer, zext to i32, then trunc and bitcast back.
* `sizeof(arg) == sizeof(i32)`: just bitcast to i32 and back (unless the argument is already i32).
* `sizeof(arg) > sizeof(i32)`: bitcast to `vector<Nxi32>`, extract the individual elements, apply a `rocdl.ds_swizzle` to each, then recompose the vector and bitcast back.

Added repacking logic to LLVM utils so it can be used elsewhere. I'm planning to use it for `gpu.shuffle` later.
---
 .../mlir/Conversion/LLVMCommon/Pattern.h      | 12 +++
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 33 +++++++
 mlir/lib/Conversion/LLVMCommon/Pattern.cpp    | 92 +++++++++++++++++++
 .../Conversion/AMDGPUToROCDL/swizzle.mlir     | 75 +++++++++++++++
 4 files changed, 212 insertions(+)
 create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/swizzle.mlir

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 66c8731ec2bf4..7a58e4fc2f984 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -31,6 +31,18 @@ LogicalResult oneToOneRewrite(
     IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
 
 } // namespace detail
+
+/// Decomposes `src` into one or more values of type `dstType` using a series
+/// of bitcasts and vector ops. `src` and `dstType` are expected to be integer
+/// or float types, or fixed-length vector types of them.
+///  - If `src` already has type `dstType`, it is returned unchanged.
+///  - If both types have the same bit width, `src` is bitcast to `dstType`.
+///  - If `src` is narrower than `dstType`, it is zero-extended into a single
+///    value.
+///  - If `src` is wider than `dstType`, it is split into several values.
+SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc, Value src,
+                                  Type dstType);
+
+/// Composes a set of `src` values into a single value of type `dstType` using
+/// a series of bitcasts and vector ops. This is the inverse of
+/// `decomposeValue`: pass it the values produced by `decomposeValue` (possibly
+/// after transforming each of them) together with the original type to
+/// reassemble the original value. `src` must not be empty.
+Value composeValue(OpBuilder &builder, Location loc, ValueRange src,
+                   Type dstType);
 } // namespace LLVM
 
 /// Base class for operation conversions targeting the LLVM IR dialect. It
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 5f697bdeef566..5c4c95699142f 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1377,6 +1377,38 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
   }
 };
 
+/// Lowers `amdgpu.swizzle_bitmode` to one or more `rocdl.ds_swizzle` ops.
+///
+/// The source is repacked into i32 pieces (ds_swizzle operates on 32-bit
+/// values), each piece is swizzled with the same bit-mode offset, and the
+/// results are recomposed into the original type.
+struct AMDGPUSwizzleBitModeLowering
+    : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Each mask occupies a 5-bit field of the ds_swizzle offset. Wider values
+    // would spill into the neighbouring fields, or into bit 15, which selects
+    // quad-permute mode instead of bit mode. Reject them before creating IR.
+    constexpr unsigned kMaskBits = 5;
+    constexpr unsigned kMaxMask = (1u << kMaskBits) - 1;
+    unsigned andMask = op.getAndMask();
+    unsigned orMask = op.getOrMask();
+    unsigned xorMask = op.getXorMask();
+    if (andMask > kMaxMask || orMask > kMaxMask || xorMask > kMaxMask)
+      return rewriter.notifyMatchFailure(
+          op, "swizzle and/or/xor masks must fit in 5 bits");
+
+    Location loc = op.getLoc();
+    Type i32 = rewriter.getI32Type();
+    Value src = adaptor.getSrc();
+    SmallVector<Value> decomposed =
+        LLVM::decomposeValue(rewriter, loc, src, i32);
+
+    // Offset layout: [4:0] and, [9:5] or, [14:10] xor; bit 15 is 0 for the
+    // BitMode swizzle.
+    unsigned mask =
+        andMask | (orMask << kMaskBits) | (xorMask << (2 * kMaskBits));
+    Value maskValue = createI32Constant(rewriter, loc, mask);
+    SmallVector<Value> swizzled;
+    swizzled.reserve(decomposed.size());
+    for (Value v : decomposed) {
+      Value res =
+          rewriter.create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue);
+      swizzled.push_back(res);
+    }
+
+    Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 struct ConvertAMDGPUToROCDLPass
     : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
   using Base::Base;
@@ -1444,4 +1476,5 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
            MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
            PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
            GatherToLDSOpLowering>(converter, chipset);
+  patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
 }
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 32bfd72475569..d2737c56369d1 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -381,3 +381,95 @@ LogicalResult LLVM::detail::oneToOneRewrite(
   rewriter.replaceOp(op, results);
   return success();
 }
+
+/// Returns the total size in bits of `type`, which must be an integer or float
+/// type, or a fixed-length vector of such types.
+static unsigned getBitWidth(Type type) {
+  if (type.isIntOrFloat())
+    return type.getIntOrFloatBitWidth();
+
+  auto vec = cast<VectorType>(type);
+  // The static shape of a scalable vector is only its minimum size, so the
+  // width computed below would be wrong for it.
+  assert(!vec.isScalable() && "scalable vectors are not supported");
+  return vec.getNumElements() * getBitWidth(vec.getElementType());
+}
+
+/// Materializes `value` as an `llvm.mlir.constant` of type i32.
+static Value createI32Constant(OpBuilder &builder, Location loc,
+                               int32_t value) {
+  return builder.create<LLVM::ConstantOp>(loc, builder.getI32Type(), value);
+}
+
+/// Splits `src` into values of type `dstType`.
+///
+/// The source bit width is rounded up to a whole number of `dstType` pieces,
+/// zero-extending when needed. This covers sources narrower than `dstType`,
+/// and sources whose width is not a multiple of it (e.g. vector<3xi16> split
+/// into i32). The padded value is then bitcast to a vector and its elements
+/// are extracted. `composeValue` removes the padding again.
+SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
+                                              Value src, Type dstType) {
+  Type srcType = src.getType();
+  if (srcType == dstType)
+    return {src};
+
+  unsigned srcBitWidth = getBitWidth(srcType);
+  unsigned dstBitWidth = getBitWidth(dstType);
+  if (srcBitWidth == dstBitWidth) {
+    Value cast = builder.create<LLVM::BitcastOp>(loc, dstType, src);
+    return {cast};
+  }
+
+  // Round the source width up to a multiple of the destination width and
+  // zero-extend the source (reinterpreted as an integer) to that width.
+  unsigned paddedBitWidth =
+      (srcBitWidth + dstBitWidth - 1) / dstBitWidth * dstBitWidth;
+  if (paddedBitWidth != srcBitWidth) {
+    auto srcInt = builder.getIntegerType(srcBitWidth);
+    if (srcType != srcInt)
+      src = builder.create<LLVM::BitcastOp>(loc, srcInt, src);
+
+    auto paddedInt = builder.getIntegerType(paddedBitWidth);
+    src = builder.create<LLVM::ZExtOp>(loc, paddedInt, src);
+    srcType = paddedInt;
+  }
+
+  // The (padded) source now fits exactly into a single `dstType` value.
+  if (paddedBitWidth == dstBitWidth) {
+    if (srcType != dstType)
+      src = builder.create<LLVM::BitcastOp>(loc, dstType, src);
+    return {src};
+  }
+
+  // Split into pieces through a vector bitcast. Vectors of vectors are not
+  // allowed, so a vector `dstType` is carried as an integer of the same width
+  // and bitcast back to `dstType` after extraction.
+  int64_t numElements = paddedBitWidth / dstBitWidth;
+  Type elemType = dstType;
+  if (!dstType.isIntOrFloat())
+    elemType = builder.getIntegerType(dstBitWidth);
+  auto vecType = VectorType::get(numElements, elemType);
+  if (srcType != vecType)
+    src = builder.create<LLVM::BitcastOp>(loc, vecType, src);
+
+  SmallVector<Value> res;
+  res.reserve(numElements);
+  for (auto i : llvm::seq<int64_t>(0, numElements)) {
+    Value idx = createI32Constant(builder, loc, i);
+    Value elem = builder.create<LLVM::ExtractElementOp>(loc, src, idx);
+    if (elemType != dstType)
+      elem = builder.create<LLVM::BitcastOp>(loc, dstType, elem);
+    res.emplace_back(elem);
+  }
+  return res;
+}
+
+/// Reassembles values produced by `decomposeValue` into a value of `dstType`.
+///
+/// All `src` values must share one type, and together they must be at least
+/// as wide as `dstType`. Multiple values are packed into a vector. Any bits
+/// beyond the width of `dstType` (the zero padding added by `decomposeValue`)
+/// are truncated away, and the result is bitcast to `dstType`.
+Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
+                               Type dstType) {
+  assert(!src.empty() && "src range must not be empty");
+  Type elemType = src.front().getType();
+  assert(llvm::all_of(src.getTypes(),
+                      [&](Type type) { return type == elemType; }) &&
+         "all src values must have the same type");
+  if (src.size() == 1 && elemType == dstType)
+    return src.front();
+
+  int64_t numElements = src.size();
+  unsigned elemBitWidth = getBitWidth(elemType);
+  unsigned dstBitWidth = getBitWidth(dstType);
+  unsigned totalBitWidth = elemBitWidth * numElements;
+  assert(totalBitWidth >= dstBitWidth &&
+         "src values are too narrow to form dstType");
+
+  Value res = src.front();
+  if (numElements > 1) {
+    // Pack the pieces into a vector. Vectors of vectors are not allowed, so
+    // vector-typed pieces are first bitcast to integers of the same width.
+    Type vecElemType = elemType;
+    if (!elemType.isIntOrFloat())
+      vecElemType = builder.getIntegerType(elemBitWidth);
+    auto vecType = VectorType::get(numElements, vecElemType);
+    res = builder.create<LLVM::PoisonOp>(loc, vecType);
+    for (auto &&[i, elem] : llvm::enumerate(src)) {
+      Value piece = elem;
+      if (piece.getType() != vecElemType)
+        piece = builder.create<LLVM::BitcastOp>(loc, vecElemType, piece);
+      Value idx = createI32Constant(builder, loc, static_cast<int32_t>(i));
+      res = builder.create<LLVM::InsertElementOp>(loc, vecType, res, piece,
+                                                  idx);
+    }
+  }
+
+  // Drop the padding bits added by `decomposeValue`.
+  if (totalBitWidth > dstBitWidth) {
+    auto totalInt = builder.getIntegerType(totalBitWidth);
+    if (res.getType() != totalInt)
+      res = builder.create<LLVM::BitcastOp>(loc, totalInt, res);
+
+    auto dstInt = builder.getIntegerType(dstBitWidth);
+    res = builder.create<LLVM::TruncOp>(loc, dstInt, res);
+  }
+
+  if (res.getType() != dstType)
+    res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
+  return res;
+}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/swizzle.mlir b/mlir/test/Conversion/AMDGPUToROCDL/swizzle.mlir
new file mode 100644
index 0000000000000..ef439efde1bd0
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/swizzle.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt -convert-amdgpu-to-rocdl --canonicalize %s | FileCheck %s
+
+// An i32 source needs no repacking: one ds_swizzle is applied directly to the
+// argument. The masks (and=1, or=2, xor=4) pack to 1 | (2 << 5) | (4 << 10),
+// which is 4161.
+// CHECK-LABEL: func @test_swizzle_i32
+// CHECK-SAME: (%[[ARG0:.*]]: i32)
+func.func @test_swizzle_i32(%arg0 : i32) -> i32 {
+// CHECK:  %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK:  %[[RES:.*]] = rocdl.ds_swizzle %[[ARG0]], %[[C]] : (i32, i32) -> i32
+// CHECK:  return %[[RES]] : i32
+  %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : i32
+  return %0 : i32
+}
+
+// A 32-bit float is bitcast to i32 for the swizzle and bitcast back to f32.
+// CHECK-LABEL: func @test_swizzle_f32
+// CHECK-SAME: (%[[ARG0:.*]]: f32)
+func.func @test_swizzle_f32(%arg0 : f32) -> f32 {
+// CHECK:  %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK:  %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
+// CHECK:  %[[RES:.*]] = rocdl.ds_swizzle %[[CAST]], %[[C]] : (i32, i32) -> i32
+// CHECK:  %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
+// CHECK:  return %[[RES_CAST]] : f32
+  %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f32
+  return %0 : f32
+}
+
+// A 16-bit source is bitcast to i16 and zero-extended to i32. The swizzled
+// result is truncated back to i16 and bitcast to f16.
+// CHECK-LABEL: func @test_swizzle_f16
+// CHECK-SAME: (%[[ARG0:.*]]: f16)
+func.func @test_swizzle_f16(%arg0 : f16) -> f16 {
+// CHECK:  %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK:  %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
+// CHECK:  %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
+// CHECK:  %[[RES:.*]] = rocdl.ds_swizzle %[[ZEXT]], %[[C]] : (i32, i32) -> i32
+// CHECK:  %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
+// CHECK:  %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
+// CHECK:  return %[[RES_CAST]] : f16
+  %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f16
+  return %0 : f16
+}
+
+// A 64-bit vector of i32 is split into its two elements. Each element is
+// swizzled separately, and the results are inserted into a new vector.
+// CHECK-LABEL: func @test_swizzle_2xi32
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2xi32>)
+func.func @test_swizzle_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
+// CHECK-DAG:  %[[V1:.*]] = llvm.mlir.poison : vector<2xi32>
+// CHECK-DAG:  %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK-DAG:  %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK-DAG:  %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:  %[[E0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
+// CHECK:  %[[E1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
+// CHECK:  %[[S1:.*]] = rocdl.ds_swizzle %[[E0]], %[[C]] : (i32, i32) -> i32
+// CHECK:  %[[S2:.*]] = rocdl.ds_swizzle %[[E1]], %[[C]] : (i32, i32) -> i32
+// CHECK:  %[[V2:.*]] = llvm.insertelement %[[S1]], %[[V1]][%[[C0]] : i32] : vector<2xi32>
+// CHECK:  %[[V3:.*]] = llvm.insertelement %[[S2]], %[[V2]][%[[C1]] : i32] : vector<2xi32>
+// CHECK:  return %[[V3]] : vector<2xi32>
+  %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<2xi32>
+  return %0 : vector<2xi32>
+}
+
+// vector<4xf16> is first bitcast to vector<2xi32> and then split and swizzled
+// element-wise. The recomposed vector<2xi32> is bitcast back to vector<4xf16>.
+// CHECK-LABEL: func @test_swizzle_4xf16
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf16>)
+func.func @test_swizzle_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
+// CHECK-DAG:  %[[V1:.*]] = llvm.mlir.poison : vector<2xi32>
+// CHECK-DAG:  %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK-DAG:  %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK-DAG:  %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:  %[[CAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<4xf16> to vector<2xi32>
+// CHECK:  %[[E0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
+// CHECK:  %[[E1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
+// CHECK:  %[[S1:.*]] = rocdl.ds_swizzle %[[E0]], %[[C]] : (i32, i32) -> i32
+// CHECK:  %[[S2:.*]] = rocdl.ds_swizzle %[[E1]], %[[C]] : (i32, i32) -> i32
+// CHECK:  %[[V2:.*]] = llvm.insertelement %[[S1]], %[[V1]][%[[C0]] : i32] : vector<2xi32>
+// CHECK:  %[[V3:.*]] = llvm.insertelement %[[S2]], %[[V2]][%[[C1]] : i32] : vector<2xi32>
+// CHECK:  %[[CAST2:.*]] = llvm.bitcast %[[V3]] : vector<2xi32> to vector<4xf16>
+// CHECK:  return %[[CAST2]] : vector<4xf16>
+  %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<4xf16>
+  return %0 : vector<4xf16>
+}



More information about the Mlir-commits mailing list