[Mlir-commits] [mlir] [MLIR] Fix rewrite of ops with vector operands to LLVM on GPU (PR #127844)
Benoit Jacob
llvmlistbot at llvm.org
Wed Feb 19 11:50:38 PST 2025
https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/127844
>From d5f407ebaecdfaee6d4e7bc15bc9f18c44f66a7c Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Wed, 19 Feb 2025 09:59:23 -0600
Subject: [PATCH 1/2] llvm-array-unroll
Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>
---
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 51 ++++++++++++-------
.../lib/Conversion/GPUCommon/GPUOpsLowering.h | 5 +-
.../Conversion/MathToROCDL/math-to-rocdl.mlir | 41 +++++++++++++++
3 files changed, 78 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index cfa434699cdef..c3b3a78abe7f7 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -9,6 +9,7 @@
#include "GPUOpsLowering.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -586,22 +587,15 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
return success();
}
-/// Unrolls op if it's operating on vectors.
-LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
- ConversionPatternRewriter &rewriter,
- const LLVMTypeConverter &converter) {
+/// Helper for impl::scalarizeVectorOp. Scalarizes vectors to elements.
+/// Used either directly (for ops on 1D vectors) or as the callback passed to
+/// detail::handleMultidimensionalVectors (for ops on higher-rank vectors).
+static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands,
+ Type llvm1DVectorTy,
+ ConversionPatternRewriter &rewriter,
+ const LLVMTypeConverter &converter) {
TypeRange operandTypes(operands);
- if (llvm::none_of(operandTypes, llvm::IsaPred<VectorType>)) {
- return rewriter.notifyMatchFailure(op, "expected vector operand");
- }
- if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
- return rewriter.notifyMatchFailure(op, "expected no region/successor");
- if (op->getNumResults() != 1)
- return rewriter.notifyMatchFailure(op, "expected single result");
- VectorType vectorType = dyn_cast<VectorType>(op->getResult(0).getType());
- if (!vectorType)
- return rewriter.notifyMatchFailure(op, "expected vector result");
-
+ VectorType vectorType = cast<VectorType>(llvm1DVectorTy);
Location loc = op->getLoc();
Value result = rewriter.create<LLVM::PoisonOp>(loc, vectorType);
Type indexType = converter.convertType(rewriter.getIndexType());
@@ -621,9 +615,32 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
result = rewriter.create<LLVM::InsertElementOp>(
loc, result, scalarOp->getResult(0), index);
}
+ return result;
+}
- rewriter.replaceOp(op, result);
- return success();
+/// Unrolls op to scalar ops over its vector or llvm.array operand elements.
+LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
+ ConversionPatternRewriter &rewriter,
+ const LLVMTypeConverter &converter) {
+ TypeRange operandTypes(operands);
+ if (llvm::any_of(operandTypes, llvm::IsaPred<VectorType>)) {
+ VectorType vectorType = cast<VectorType>(op->getResultTypes()[0]);
+ rewriter.replaceOp(op, scalarizeVectorOpHelper(op, operands, vectorType,
+ rewriter, converter));
+ return success();
+ }
+
+ if (llvm::any_of(operandTypes, llvm::IsaPred<LLVM::LLVMArrayType>)) {
+ return LLVM::detail::handleMultidimensionalVectors(
+ op, operands, converter,
+ [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
+ return scalarizeVectorOpHelper(op, operands, llvm1DVectorTy, rewriter,
+ converter);
+ },
+ rewriter);
+ }
+
+ return rewriter.notifyMatchFailure(op, "no llvm.array or vector to unroll");
}
static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index e73a74845d2b6..bd2fd020f684b 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -172,13 +172,13 @@ struct GPUReturnOpLowering : public ConvertOpToLLVMPattern<gpu::ReturnOp> {
};
namespace impl {
-/// Unrolls op if it's operating on vectors.
+/// Unrolls op to scalar ops over its vector or llvm.array operand elements.
LogicalResult scalarizeVectorOp(Operation *op, ValueRange operands,
ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &converter);
} // namespace impl
-/// Rewriting that unrolls SourceOp to scalars if it's operating on vectors.
+/// Pattern unrolling SourceOp to scalar ops over vector/llvm.array elements.
template <typename SourceOp>
struct ScalarizeVectorOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
public:
@@ -191,6 +191,7 @@ struct ScalarizeVectorOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
*this->getTypeConverter());
}
};
+
} // namespace mlir
#endif // MLIR_CONVERSION_GPUCOMMON_GPUOPSLOWERING_H_
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index e4b2f01d6544a..b6493ca9b32c3 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -513,3 +513,44 @@ module {
"test.possible_terminator"() : () -> ()
}) : () -> ()
}
+
+// -----
+
+module @test_module {
+ // CHECK-LABEL: func @math_sin_vector_1d
+ func.func @math_sin_vector_1d(%arg : vector<4xf16>) -> vector<4xf16> {
+ // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+ %result = math.sin %arg : vector<4xf16>
+ func.return %result : vector<4xf16>
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK-LABEL: func @math_sin_vector_2d
+ func.func @math_sin_vector_2d(%arg : vector<2x2xf16>) -> vector<2x2xf16> {
+ // CHECK: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ %result = math.sin %arg : vector<2x2xf16>
+ func.return %result : vector<2x2xf16>
+ }
+}
>From 524888966a756eecef92189ff228b758e56bf8fa Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Wed, 19 Feb 2025 13:50:21 -0600
Subject: [PATCH 2/2] review-comment
Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>
---
mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir | 10 ++++++++++
1 file changed, 10 insertions(+)
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index b6493ca9b32c3..9448304f11dbd 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -517,15 +517,20 @@ module {
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
// CHECK-LABEL: func @math_sin_vector_1d
func.func @math_sin_vector_1d(%arg : vector<4xf16>) -> vector<4xf16> {
// CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
// CHECK: llvm.insertelement {{.*}} : vector<4xf16>
// CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
// CHECK: llvm.insertelement {{.*}} : vector<4xf16>
// CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
// CHECK: llvm.insertelement {{.*}} : vector<4xf16>
// CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
// CHECK: llvm.insertelement {{.*}} : vector<4xf16>
%result = math.sin %arg : vector<4xf16>
func.return %result : vector<4xf16>
@@ -535,19 +540,24 @@ module @test_module {
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
// CHECK-LABEL: func @math_sin_vector_2d
func.func @math_sin_vector_2d(%arg : vector<2x2xf16>) -> vector<2x2xf16> {
// CHECK: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>>
// CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
// CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
// CHECK: llvm.insertelement {{.*}} : vector<2xf16>
// CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
// CHECK: llvm.insertelement {{.*}} : vector<2xf16>
// CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
// CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
// CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
// CHECK: llvm.insertelement {{.*}} : vector<2xf16>
// CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
// CHECK: llvm.insertelement {{.*}} : vector<2xf16>
// CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
%result = math.sin %arg : vector<2x2xf16>
More information about the Mlir-commits
mailing list