[Mlir-commits] [mlir] [MLIR][AMDGPU] Introduce fp16 packed arithmetic (PR #105688)

Giuseppe Rossini llvmlistbot at llvm.org
Fri Aug 23 05:57:00 PDT 2024


https://github.com/giuseros updated https://github.com/llvm/llvm-project/pull/105688

>From 169dcaf799160c994d8e0091bc09cc34ac655125 Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Thu, 22 Aug 2024 17:14:42 +0100
Subject: [PATCH 1/2] [MLIR][AMDGPU] Introduce fp16 packed arithmetic

---
 .../Conversion/ArithToAMDGPU/ArithToAMDGPU.h  |   7 +-
 mlir/include/mlir/Conversion/Passes.td        |   6 +
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |   1 +
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td  |  17 ++-
 .../ArithToAMDGPU/ArithToAMDGPU.cpp           | 110 ++++++++++++++++--
 .../Conversion/ArithToAMDGPU/CMakeLists.txt   |   1 +
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  |   1 +
 mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt     |   1 +
 .../ArithToAMDGPU/16-bit-floats.mlir          |  51 ++++++++
 .../ArithToAMDGPU/8-bit-float-saturation.mlir |   2 +-
 .../ArithToAMDGPU/8-bit-floats.mlir           |   2 +-
 mlir/test/Target/LLVMIR/rocdl.mlir            |   6 +
 12 files changed, 194 insertions(+), 11 deletions(-)
 create mode 100644 mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir

diff --git a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
index 78c79c915e0607..28fdc234e5ef07 100644
--- a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
+++ b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
@@ -9,7 +9,9 @@
 #ifndef MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
 #define MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
 
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
 #include <memory>
+#include <string>
 
 namespace mlir {
 
@@ -26,7 +28,10 @@ namespace arith {
 /// to the largest value of that type instead of being rewritten to Inf (aka
 /// NaN).
 void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
-                                             bool saturateFP8TruncF);
+                                             bool convertFP8Arithmetic,
+                                             bool saturateFP8Truncf,
+                                             bool allowPackedF16Rtz,
+                                             amdgpu::Chipset chipset);
 } // namespace arith
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b5bb2f42f2961c..24dc3b67db5a56 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -150,9 +150,15 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
   let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
 
   let options = [
+    Option<"chipset", "chipset", "std::string",
+                        /*default=*/"\"gfx000\"",
+                        "Chipset that these operations will run on">,
     Option<"saturateFP8Truncf", "saturate-fp8-truncf", "bool",
            /*default=*/"false",
            "Use saturating truncation for 8-bit float types">,
+    Option<"allowPackedF16Rtz", "allow-packed-f16-round-to-zero", "bool",
+           /*default=*/"false",
+           "Whether we should allow f32->f16 packed round-to-zero conversion">,
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index aa2b4543927a7f..d6fcf7329b6099 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -25,6 +25,7 @@ def AMDGPU_Dialect : Dialect {
 
 
   let dependentDialects = [
+    "ROCDL::ROCDLDialect",
     "arith::ArithDialect",
     "gpu::GPUDialect"
   ];
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 868208ff74a521..082148ddb13d6f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -165,7 +165,7 @@ def ROCDL_BallotOp :
   let summary = "Vote across thread group";
 
   let description = [{
-      Ballot provides a bit mask containing the 1-bit predicate value from each lane. 
+      Ballot provides a bit mask containing the 1-bit predicate value from each lane.
       The nth bit of the result contains the 1 bit contributed by the nth warp lane.
   }];
 
@@ -554,6 +554,21 @@ def ROCDL_RawBufferAtomicUMinOp :
   let hasCustomAssemblyFormat = 1;
 }
 
+//===---------------------------------------------------------------------===//
+// 16-bit float intrinsics
+//===---------------------------------------------------------------------===//
+def ROCDL_CvtPkRtz:
+    ROCDL_IntrOp<"cvt.pkrtz", [], [], [Pure], 1>,
+    Arguments<(ins F32:$srcA, F32:$srcB)> {
+  let summary = "Convert two f32 inputs into a vector<2xf16>";
+  let description = [{
+    Convert two f32 values into a packed vector<2xf16>.
+  }];
+  let assemblyFormat = [{
+    attr-dict $srcA `,` $srcB `:` type($res)
+  }];
+}
+
 //===---------------------------------------------------------------------===//
 // 8-bit float intrinsics
 //===---------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index b3798a3f7624b0..5c37ec536d8963 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -9,8 +9,11 @@
 #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
@@ -24,6 +27,7 @@ namespace mlir {
 } // namespace mlir
 
 using namespace mlir;
+using namespace mlir::amdgpu;
 
 namespace {
 struct ArithToAMDGPUConversionPass final
@@ -43,12 +47,25 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
 
 struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
   bool saturateFP8 = false;
-  TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8)
-      : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8) {}
+  TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
+                               Chipset chipset)
+      : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
+        chipset(chipset) {}
+  Chipset chipset;
 
   LogicalResult match(arith::TruncFOp op) const override;
   void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
 };
+
+struct TruncfToFloat16RewritePattern final
+    : public OpRewritePattern<arith::TruncFOp> {
+
+  using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
+
+  LogicalResult match(arith::TruncFOp op) const override;
+  void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
+};
+
 } // end namespace
 
 static Value castF32To(Type elementType, Value f32, Location loc,
@@ -272,17 +289,96 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
   rewriter.replaceOp(op, result);
 }
 
+LogicalResult TruncfToFloat16RewritePattern::match(arith::TruncFOp op) const {
+  Type outType = op.getOut().getType();
+  Type inputType = getElementTypeOrSelf(op.getIn());
+  if (auto outVecType = dyn_cast<VectorType>(outType)) {
+    if (outVecType.isScalable())
+      return failure();
+    if (outVecType.getShape().size() > 1)
+      // Multi-dimensional vectors are currently unsupported.
+      return failure();
+    outType = outVecType.getElementType();
+  }
+  return success(outType.isF16() && inputType.isF32());
+}
+
+void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
+                                            PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  Value in = op.getIn();
+  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
+  VectorType truncResType = VectorType::get(2, outElemType);
+
+  // Handle the case where input type is not a vector type
+  if (!isa<VectorType>(in.getType())) {
+    auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
+    Value asF16s =
+        rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
+    Value result = rewriter.create<vector::ExtractElementOp>(
+        loc, asF16s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
+    return rewriter.replaceOp(op, result);
+  }
+  VectorType outType = cast<VectorType>(op.getOut().getType());
+  int64_t numElements = outType.getNumElements();
+  Value zero = rewriter.createOrFold<arith::ConstantOp>(
+      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
+
+  // Handle the vector case. We also handle the (uncommon) case where the vector
+  // length is odd
+  for (int64_t i = 0; i < numElements; i += 2) {
+    int64_t elemsThisOp = std::min(numElements, i + 2) - i;
+    Value thisResult = nullptr;
+    Value elemA = rewriter.create<vector::ExtractElementOp>(
+        loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i));
+    Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
+
+    if (elemsThisOp == 2) {
+      elemB = rewriter.create<vector::ExtractElementOp>(
+          loc, in, rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + 1));
+    }
+
+    thisResult =
+        rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
+    // Place back the truncated result into the possibly larger vector. If we
+    // are operating on a size 2 vector, these operations should be folded away
+    thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
+        loc, thisResult, 0, elemsThisOp, 1);
+    result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
+                                                           result, i, 1);
+  }
+  rewriter.replaceOp(op, result);
+}
+
 void mlir::arith::populateArithToAMDGPUConversionPatterns(
-    RewritePatternSet &patterns, bool saturateFP8TruncF) {
-  patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
-  patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
-                                             saturateFP8TruncF);
+    RewritePatternSet &patterns, bool convertFP8Arithmetic,
+    bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
+
+  if (convertFP8Arithmetic) {
+    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
+    patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
+                                               saturateFP8Truncf, chipset);
+  }
+  if (allowPackedF16Rtz)
+    patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
 }
 
 void ArithToAMDGPUConversionPass::runOnOperation() {
   Operation *op = getOperation();
+  MLIRContext *ctx = &getContext();
   RewritePatternSet patterns(op->getContext());
-  arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8Truncf);
+  FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
+  if (failed(maybeChipset)) {
+    emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
+    return signalPassFailure();
+  }
+
+  bool convertFP8Arithmetic =
+      (*maybeChipset).majorVersion == 9 && (*maybeChipset).minorVersion >= 0x40;
+  arith::populateArithToAMDGPUConversionPatterns(
+      patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
+      *maybeChipset);
   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
     return signalPassFailure();
 }
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt b/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
index e2c951b0b34d8b..50be09ab5a7c5b 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
+++ b/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRArithToAMDGPU
 
   LINK_LIBS PUBLIC
   MLIRAMDGPUDialect
+  MLIRAMDGPUUtils
   MLIRArithDialect
   MLIRArithUtils
   MLIRVectorDialect
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index e3beceaa3bbb5b..0b1dd79ded3a71 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
diff --git a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
index 0551d13b5a0cf0..78d78cf48a747c 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
+  MLIRROCDLDialect
   # Needed for GPU address space enum definition
   MLIRGPUDialect
   MLIRIR
diff --git a/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
new file mode 100644
index 00000000000000..121cae26748a82
--- /dev/null
+++ b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="allow-packed-f16-round-to-zero=true" | FileCheck %s
+
+// CHECK-LABEL: @scalar_trunc
+// CHECK-SAME: (%[[value:.*]]: f32)
+func.func @scalar_trunc(%v: f32) -> f16{
+  // CHECK: %[[poison:.*]] = llvm.mlir.poison : f32
+  // CHECK: %[[trunc:.*]] = rocdl.cvt.pkrtz %[[value]], %[[poison]] : vector<2xf16>
+  // CHECK: %[[extract:.*]] = vector.extractelement %[[trunc]][%c0 : index] : vector<2xf16>
+  // CHECK: return %[[extract]] : f16
+  %w = arith.truncf %v : f32 to f16
+  return %w : f16
+}
+
+// CHECK-LABEL: @vector_trunc
+// CHECK-SAME: (%[[value:.*]]: vector<2xf32>)
+func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
+  // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]]
+  // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]]
+  // CHECK: %[[ret:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
+  // CHECK: return %[[ret]]
+  %w = arith.truncf %v : vector<2xf32> to vector<2xf16>
+  return %w : vector<2xf16>
+}
+
+// CHECK-LABEL:  @vector_trunc_long
+// CHECK-SAME: (%[[value:.*]]: vector<9xf32>)
+func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf16> {
+  // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]][%c0 : index]
+  // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]][%c1 : index]
+  // CHECK: %[[packed0:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
+  // CHECK: %[[out0:.*]] = vector.insert_strided_slice %[[packed0]], {{.*}} {offsets = [0], strides = [1]} : vector<2xf16> into vector<9xf16>
+  // CHECK: %[[elem2:.*]] = vector.extractelement %[[value]][%c2 : index]
+  // CHECK: %[[elem3:.*]] = vector.extractelement %[[value]][%c3 : index]
+  // CHECK: %[[packed1:.*]] = rocdl.cvt.pkrtz %[[elem2]], %[[elem3]] : vector<2xf16>
+  // CHECK: %[[out1:.*]] = vector.insert_strided_slice %[[packed1]], %[[out0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<9xf16>
+  // CHECK: %[[elem4:.*]] = vector.extractelement %[[value]][%c4 : index]
+  // CHECK: %[[elem5:.*]] = vector.extractelement %[[value]][%c5 : index]
+  // CHECK: %[[packed2:.*]] = rocdl.cvt.pkrtz %[[elem4]], %[[elem5]] : vector<2xf16>
+  // CHECK: %[[out2:.*]] = vector.insert_strided_slice %[[packed2]], %[[out1]] {offsets = [4], strides = [1]} : vector<2xf16> into vector<9xf16>
+  // CHECK: %[[elem6:.*]] = vector.extractelement %[[value]]
+  // CHECK: %[[elem7:.*]] = vector.extractelement %[[value]]
+  // CHECK: %[[packed3:.*]] = rocdl.cvt.pkrtz %[[elem6]], %[[elem7]] : vector<2xf16>
+  // CHECK: %[[out3:.*]] = vector.insert_strided_slice %[[packed3]], %[[out2]] {offsets = [6], strides = [1]} : vector<2xf16> into vector<9xf16>
+  // CHECK: %[[elem8:.*]] = vector.extractelement %[[value]]
+  // CHECK: %[[packed4:.*]] = rocdl.cvt.pkrtz %[[elem8:.*]] : vector<2xf16>
+  // CHECK: %[[slice:.*]] = vector.extract_strided_slice %[[packed4]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16>
+  // CHECK: %[[out4:.*]] = vector.insert_strided_slice %[[slice]], %[[out3]] {offsets = [8], strides = [1]} : vector<1xf16> into vector<9xf16>
+  // CHECK: return %[[out4]]
+  %w = arith.truncf %v : vector<9xf32> to vector<9xf16>
+  return %w : vector<9xf16>
+}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir
index c7f39440a349b0..cd921da2294e13 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt --split-input-file %s \
-// RUN:   --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{saturate-fp8-truncf=true}))' \
+// RUN:   --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{chipset=gfx940 saturate-fp8-truncf=true}))' \
 // RUN:   | FileCheck %s
 
 // CHECK-LABEL: func.func @scalar_trunc
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
index 26a222a4a788e5..bd90facb615440 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu | FileCheck %s
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx940" | FileCheck %s
 
 // CHECK-LABEL: func.func @scalar_ext
 // CHECK-SAME: ([[V:%.+]]: f8E5M2FNUZ)
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 78c3987fab648e..d04978ff6deeb7 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -516,6 +516,12 @@ llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
   llvm.return %source5 : i32
 }
 
+llvm.func @rocdl_16bit_packed_floats(%sourceA: f32, %sourceB: f32) -> vector<2xf16> {
+  // CHECK: call <2 x half> @llvm.amdgcn.cvt.pkrtz(float {{.*}}, float {{.*}})
+  %source = rocdl.cvt.pkrtz %sourceA, %sourceB  : vector<2xf16>
+  llvm.return %source : vector<2xf16>
+}
+
 // CHECK-DAG: attributes #[[$KERNEL_ATTRS]] = { "amdgpu-flat-work-group-size"="1,256" "uniform-work-group-size"="true" }
 // CHECK-DAG: attributes #[[$KERNEL_WORKGROUP_ATTRS]] = { "amdgpu-flat-work-group-size"="1,1024"
 // CHECK-DAG: attributes #[[$KNOWN_BLOCK_SIZE_ATTRS]] = { "amdgpu-flat-work-group-size"="128,128"

>From 42eecb561570e1ff1dd5d042a7db338f0bc3e55b Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Fri, 23 Aug 2024 13:56:40 +0100
Subject: [PATCH 2/2] Address review feedback

---
 .../Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp  | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 5c37ec536d8963..d36583c8118ff4 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -295,9 +295,6 @@ LogicalResult TruncfToFloat16RewritePattern::match(arith::TruncFOp op) const {
   if (auto outVecType = dyn_cast<VectorType>(outType)) {
     if (outVecType.isScalable())
       return failure();
-    if (outVecType.getShape().size() > 1)
-      // Multi-dimensional vectors are currently unsupported.
-      return failure();
     outType = outVecType.getElementType();
   }
   return success(outType.isF16() && inputType.isF32());
@@ -309,9 +306,10 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
   VectorType truncResType = VectorType::get(2, outElemType);
+  auto inVectorTy = dyn_cast<VectorType>(in.getType());
 
   // Handle the case where input type is not a vector type
-  if (!isa<VectorType>(in.getType())) {
+  if (!inVectorTy) {
     auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
     Value asF16s =
         rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
@@ -325,6 +323,12 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
   Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
 
+  if (inVectorTy.getRank() > 1) {
+    inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
+                                 inVectorTy.getElementType());
+    in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
+  }
+
   // Handle the vector case. We also handle the (uncommon) case where the vector
   // length is odd
   for (int64_t i = 0; i < numElements; i += 2) {
@@ -348,6 +352,11 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
     result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
                                                            result, i, 1);
   }
+
+  if (inVectorTy.getRank() != outType.getRank()) {
+    result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
+  }
+
   rewriter.replaceOp(op, result);
 }
 



More information about the Mlir-commits mailing list