[Mlir-commits] [mlir] [MLIR][AMDGPU] Introduce fp16 packed arithmetic (PR #105688)
Giuseppe Rossini
llvmlistbot at llvm.org
Fri Aug 23 05:57:00 PDT 2024
https://github.com/giuseros updated https://github.com/llvm/llvm-project/pull/105688
From 169dcaf799160c994d8e0091bc09cc34ac655125 Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Thu, 22 Aug 2024 17:14:42 +0100
Subject: [PATCH 1/2] [MLIR][AMDGPU] Introduce fp16 packed arithmetic
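
This patch lowers f32 -> f16 arith.truncf to the packed round-to-zero
intrinsic rocdl.cvt.pkrtz, behind a new allow-packed-f16-round-to-zero
pass option (the lowering changes rounding behaviour, so it is opt-in).
The pass also gains a chipset option, and the fp8 patterns are now only
registered on chipsets that support them (gfx940 and later).

Rough sketch of the scalar lowering, following the new
16-bit-floats.mlir test below (the second pkrtz operand is poison
because only one lane of the packed result is used):

  %w = arith.truncf %v : f32 to f16
  // becomes approximately:
  %poison = llvm.mlir.poison : f32
  %pk = rocdl.cvt.pkrtz %v, %poison : vector<2xf16>
  %w = vector.extractelement %pk[%c0 : index] : vector<2xf16>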
---
.../Conversion/ArithToAMDGPU/ArithToAMDGPU.h | 7 +-
mlir/include/mlir/Conversion/Passes.td | 6 +
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 1 +
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 17 ++-
.../ArithToAMDGPU/ArithToAMDGPU.cpp | 110 ++++++++++++++++--
.../Conversion/ArithToAMDGPU/CMakeLists.txt | 1 +
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 1 +
mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt | 1 +
.../ArithToAMDGPU/16-bit-floats.mlir | 51 ++++++++
.../ArithToAMDGPU/8-bit-float-saturation.mlir | 2 +-
.../ArithToAMDGPU/8-bit-floats.mlir | 2 +-
mlir/test/Target/LLVMIR/rocdl.mlir | 6 +
12 files changed, 194 insertions(+), 11 deletions(-)
create mode 100644 mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
diff --git a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
index 78c79c915e0607..28fdc234e5ef07 100644
--- a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
+++ b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
@@ -9,7 +9,9 @@
#ifndef MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
#define MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include <memory>
+#include <string>
namespace mlir {
@@ -26,7 +28,10 @@ namespace arith {
/// to the largest value of that type instead of being rewritten to Inf (aka
/// NaN).
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
- bool saturateFP8TruncF);
+ bool convertFP8Arithmetic,
+ bool saturateFP8Truncf,
+ bool allowPackedF16Rtz,
+ amdgpu::Chipset chipset);
} // namespace arith
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b5bb2f42f2961c..24dc3b67db5a56 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -150,9 +150,15 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
let options = [
+ Option<"chipset", "chipset", "std::string",
+ /*default=*/"\"gfx000\"",
+ "Chipset that these operations will run on">,
Option<"saturateFP8Truncf", "saturate-fp8-truncf", "bool",
/*default=*/"false",
"Use saturating truncation for 8-bit float types">,
+ Option<"allowPackedF16Rtz", "allow-packed-f16-round-to-zero", "bool",
+ /*default=*/"false",
+ "Whether we should allow f32->f16 packed round-to-zero conversion">,
];
}
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index aa2b4543927a7f..d6fcf7329b6099 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -25,6 +25,7 @@ def AMDGPU_Dialect : Dialect {
let dependentDialects = [
+ "ROCDL::ROCDLDialect",
"arith::ArithDialect",
"gpu::GPUDialect"
];
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 868208ff74a521..082148ddb13d6f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -165,7 +165,7 @@ def ROCDL_BallotOp :
let summary = "Vote across thread group";
let description = [{
- Ballot provides a bit mask containing the 1-bit predicate value from each lane.
+ Ballot provides a bit mask containing the 1-bit predicate value from each lane.
The nth bit of the result contains the 1 bit contributed by the nth warp lane.
}];
@@ -554,6 +554,21 @@ def ROCDL_RawBufferAtomicUMinOp :
let hasCustomAssemblyFormat = 1;
}
+//===---------------------------------------------------------------------===//
+// 16-bit float intrinsics
+//===---------------------------------------------------------------------===//
+def ROCDL_CvtPkRtz:
+ ROCDL_IntrOp<"cvt.pkrtz", [], [], [Pure], 1>,
+ Arguments<(ins F32:$srcA, F32:$srcB)> {
+ let summary = "Convert two f32 values into a vector<2xf16>";
+ let description = [{
+ Convert two f32 values into a packed vector<2xf16>.
+ }];
+ let assemblyFormat = [{
+ attr-dict $srcA `,` $srcB `:` type($res)
+ }];
+}
+
//===---------------------------------------------------------------------===//
// 8-bit float intrinsics
//===---------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index b3798a3f7624b0..5c37ec536d8963 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -9,8 +9,11 @@
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
@@ -24,6 +27,7 @@ namespace mlir {
} // namespace mlir
using namespace mlir;
+using namespace mlir::amdgpu;
namespace {
struct ArithToAMDGPUConversionPass final
@@ -43,12 +47,25 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
bool saturateFP8 = false;
- TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8)
- : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8) {}
+ TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
+ Chipset chipset)
+ : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
+ chipset(chipset) {}
+ Chipset chipset;
LogicalResult match(arith::TruncFOp op) const override;
void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
};
+
+struct TruncfToFloat16RewritePattern final
+ : public OpRewritePattern<arith::TruncFOp> {
+
+ using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
+
+ LogicalResult match(arith::TruncFOp op) const override;
+ void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
+};
+
} // end namespace
static Value castF32To(Type elementType, Value f32, Location loc,
@@ -272,17 +289,96 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
rewriter.replaceOp(op, result);
}
+LogicalResult TruncfToFloat16RewritePattern::match(arith::TruncFOp op) const {
+ Type outType = op.getOut().getType();
+ Type inputType = getElementTypeOrSelf(op.getIn());
+ if (auto outVecType = dyn_cast<VectorType>(outType)) {
+ if (outVecType.isScalable())
+ return failure();
+ if (outVecType.getShape().size() > 1)
+ // Multi-dimensional vectors are currently unsupported.
+ return failure();
+ outType = outVecType.getElementType();
+ }
+ return success(outType.isF16() && inputType.isF32());
+}
+
+void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
+ PatternRewriter &rewriter) const {
+ Location loc = op.getLoc();
+ Value in = op.getIn();
+ Type outElemType = getElementTypeOrSelf(op.getOut().getType());
+ VectorType truncResType = VectorType::get(2, outElemType);
+
+ // Handle the case where input type is not a vector type
+ if (!isa<VectorType>(in.getType())) {
+ auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
+ Value asF16s =
+ rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
+ Value result = rewriter.create<vector::ExtractElementOp>(
+ loc, asF16s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
+ return rewriter.replaceOp(op, result);
+ }
+ VectorType outType = cast<VectorType>(op.getOut().getType());
+ int64_t numElements = outType.getNumElements();
+ Value zero = rewriter.createOrFold<arith::ConstantOp>(
+ loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
+ Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
+
+ // Handle the vector case. We also handle the (uncommon) case where the vector
+ // length is odd
+ for (int64_t i = 0; i < numElements; i += 2) {
+ int64_t elemsThisOp = std::min(numElements, i + 2) - i;
+ Value thisResult = nullptr;
+ Value elemA = rewriter.create<vector::ExtractElementOp>(
+ loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i));
+ Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
+
+ if (elemsThisOp == 2) {
+ elemB = rewriter.create<vector::ExtractElementOp>(
+ loc, in, rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + 1));
+ }
+
+ thisResult =
+ rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
+ // Place back the truncated result into the possibly larger vector. If we
+ // are operating on a size 2 vector, these operations should be folded away
+ thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, thisResult, 0, elemsThisOp, 1);
+ result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
+ result, i, 1);
+ }
+ rewriter.replaceOp(op, result);
+}
+
void mlir::arith::populateArithToAMDGPUConversionPatterns(
- RewritePatternSet &patterns, bool saturateFP8TruncF) {
- patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
- patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
- saturateFP8TruncF);
+ RewritePatternSet &patterns, bool convertFP8Arithmetic,
+ bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
+
+ if (convertFP8Arithmetic) {
+ patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
+ patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
+ saturateFP8Truncf, chipset);
+ }
+ if (allowPackedF16Rtz)
+ patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
}
void ArithToAMDGPUConversionPass::runOnOperation() {
Operation *op = getOperation();
+ MLIRContext *ctx = &getContext();
RewritePatternSet patterns(op->getContext());
- arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8Truncf);
+ FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
+ if (failed(maybeChipset)) {
+ emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
+ return signalPassFailure();
+ }
+
+ bool convertFP8Arithmetic =
+ (*maybeChipset).majorVersion == 9 && (*maybeChipset).minorVersion >= 0x40;
+ arith::populateArithToAMDGPUConversionPatterns(
+ patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
+ *maybeChipset);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
return signalPassFailure();
}
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt b/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
index e2c951b0b34d8b..50be09ab5a7c5b 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
+++ b/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRArithToAMDGPU
LINK_LIBS PUBLIC
MLIRAMDGPUDialect
+ MLIRAMDGPUUtils
MLIRArithDialect
MLIRArithUtils
MLIRVectorDialect
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index e3beceaa3bbb5b..0b1dd79ded3a71 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
diff --git a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
index 0551d13b5a0cf0..78d78cf48a747c 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
LINK_LIBS PUBLIC
MLIRArithDialect
+ MLIRROCDLDialect
# Needed for GPU address space enum definition
MLIRGPUDialect
MLIRIR
diff --git a/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
new file mode 100644
index 00000000000000..121cae26748a82
--- /dev/null
+++ b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="allow-packed-f16-round-to-zero=true" | FileCheck %s
+
+// CHECK-LABEL: @scalar_trunc
+// CHECK-SAME: (%[[value:.*]]: f32)
+func.func @scalar_trunc(%v: f32) -> f16{
+ // CHECK: %[[poison:.*]] = llvm.mlir.poison : f32
+ // CHECK: %[[trunc:.*]] = rocdl.cvt.pkrtz %[[value]], %[[poison]] : vector<2xf16>
+ // CHECK: %[[extract:.*]] = vector.extractelement %[[trunc]][%c0 : index] : vector<2xf16>
+ // CHECK: return %[[extract]] : f16
+ %w = arith.truncf %v : f32 to f16
+ return %w : f16
+}
+
+// CHECK-LABEL: @vector_trunc_short
+// CHECK-SAME: (%[[value:.*]]: vector<2xf32>)
+func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
+ // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]]
+ // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]]
+ // CHECK: %[[ret:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
+ // CHECK: return %[[ret]]
+ %w = arith.truncf %v : vector<2xf32> to vector<2xf16>
+ return %w : vector<2xf16>
+}
+
+// CHECK-LABEL: @vector_trunc_long
+// CHECK-SAME: (%[[value:.*]]: vector<9xf32>)
+func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf16> {
+ // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]][%c0 : index]
+ // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]][%c1 : index]
+ // CHECK: %[[packed0:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
+ // CHECK: %[[out0:.*]] = vector.insert_strided_slice %[[packed0]], {{.*}} {offsets = [0], strides = [1]} : vector<2xf16> into vector<9xf16>
+ // CHECK: %[[elem2:.*]] = vector.extractelement %[[value]][%c2 : index]
+ // CHECK: %[[elem3:.*]] = vector.extractelement %[[value]][%c3 : index]
+ // CHECK: %[[packed1:.*]] = rocdl.cvt.pkrtz %[[elem2]], %[[elem3]] : vector<2xf16>
+ // CHECK: %[[out1:.*]] = vector.insert_strided_slice %[[packed1]], %[[out0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<9xf16>
+ // CHECK: %[[elem4:.*]] = vector.extractelement %[[value]][%c4 : index]
+ // CHECK: %[[elem5:.*]] = vector.extractelement %[[value]][%c5 : index]
+ // CHECK: %[[packed2:.*]] = rocdl.cvt.pkrtz %[[elem4]], %[[elem5]] : vector<2xf16>
+ // CHECK: %[[out2:.*]] = vector.insert_strided_slice %[[packed2]], %[[out1]] {offsets = [4], strides = [1]} : vector<2xf16> into vector<9xf16>
+ // CHECK: %[[elem6:.*]] = vector.extractelement %[[value]]
+ // CHECK: %[[elem7:.*]] = vector.extractelement %[[value]]
+ // CHECK: %[[packed3:.*]] = rocdl.cvt.pkrtz %[[elem6]], %[[elem7]] : vector<2xf16>
+ // CHECK: %[[out3:.*]] = vector.insert_strided_slice %[[packed3]], %[[out2]] {offsets = [6], strides = [1]} : vector<2xf16> into vector<9xf16>
+ // CHECK: %[[elem8:.*]] = vector.extractelement %[[value]]
+ // CHECK: %[[packed4:.*]] = rocdl.cvt.pkrtz %[[elem8:.*]] : vector<2xf16>
+ // CHECK: %[[slice:.*]] = vector.extract_strided_slice %[[packed4]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16>
+ // CHECK: %[[out4:.*]] = vector.insert_strided_slice %[[slice]], %[[out3]] {offsets = [8], strides = [1]} : vector<1xf16> into vector<9xf16>
+ // CHECK: return %[[out4]]
+ %w = arith.truncf %v : vector<9xf32> to vector<9xf16>
+ return %w : vector<9xf16>
+}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir
index c7f39440a349b0..cd921da2294e13 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt --split-input-file %s \
-// RUN: --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{saturate-fp8-truncf=true}))' \
+// RUN: --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{chipset=gfx940 saturate-fp8-truncf=true}))' \
// RUN: | FileCheck %s
// CHECK-LABEL: func.func @scalar_trunc
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
index 26a222a4a788e5..bd90facb615440 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu | FileCheck %s
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx940" | FileCheck %s
// CHECK-LABEL: func.func @scalar_ext
// CHECK-SAME: ([[V:%.+]]: f8E5M2FNUZ)
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 78c3987fab648e..d04978ff6deeb7 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -516,6 +516,12 @@ llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
llvm.return %source5 : i32
}
+llvm.func @rocdl_16bit_packed_floats(%sourceA: f32, %sourceB: f32) -> vector<2xf16> {
+ // CHECK: call <2 x half> @llvm.amdgcn.cvt.pkrtz(float {{.*}}, float {{.*}})
+ %source = rocdl.cvt.pkrtz %sourceA, %sourceB : vector<2xf16>
+ llvm.return %source : vector<2xf16>
+}
+
// CHECK-DAG: attributes #[[$KERNEL_ATTRS]] = { "amdgpu-flat-work-group-size"="1,256" "uniform-work-group-size"="true" }
// CHECK-DAG: attributes #[[$KERNEL_WORKGROUP_ATTRS]] = { "amdgpu-flat-work-group-size"="1,1024"
// CHECK-DAG: attributes #[[$KNOWN_BLOCK_SIZE_ATTRS]] = { "amdgpu-flat-work-group-size"="128,128"
From 42eecb561570e1ff1dd5d042a7db338f0bc3e55b Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Fri, 23 Aug 2024 13:56:40 +0100
Subject: [PATCH 2/2] Address review feedback
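
Main change: handle multi-dimensional vector inputs instead of
rejecting them. Rank > 1 vectors are flattened with vector.shape_cast,
converted as a 1-D vector, and the result is cast back to the original
shape at the end.

A rough sketch of the intended lowering for a hypothetical 2x2 input
(not a verbatim test case from this patch):

  %w = arith.truncf %v : vector<2x2xf32> to vector<2x2xf16>
  // becomes approximately:
  %flat = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32>
  //   ... pairs of elements converted via rocdl.cvt.pkrtz and
  //   re-inserted with vector.insert_strided_slice ...
  %w = vector.shape_cast %flat_result : vector<4xf16> to vector<2x2xf16>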
---
.../Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp | 17 +++++++++++++----
1 file changed, 13 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 5c37ec536d8963..d36583c8118ff4 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -295,9 +295,6 @@ LogicalResult TruncfToFloat16RewritePattern::match(arith::TruncFOp op) const {
if (auto outVecType = dyn_cast<VectorType>(outType)) {
if (outVecType.isScalable())
return failure();
- if (outVecType.getShape().size() > 1)
- // Multi-dimensional vectors are currently unsupported.
- return failure();
outType = outVecType.getElementType();
}
return success(outType.isF16() && inputType.isF32());
@@ -309,9 +306,10 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
VectorType truncResType = VectorType::get(2, outElemType);
+ auto inVectorTy = dyn_cast<VectorType>(in.getType());
// Handle the case where input type is not a vector type
- if (!isa<VectorType>(in.getType())) {
+ if (!inVectorTy) {
auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
Value asF16s =
rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
@@ -325,6 +323,12 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
+ if (inVectorTy.getRank() > 1) {
+ inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
+ inVectorTy.getElementType());
+ in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
+ }
+
// Handle the vector case. We also handle the (uncommon) case where the vector
// length is odd
for (int64_t i = 0; i < numElements; i += 2) {
@@ -348,6 +352,11 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
result, i, 1);
}
+
+ if (inVectorTy.getRank() != outType.getRank()) {
+ result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
+ }
+
rewriter.replaceOp(op, result);
}