[Mlir-commits] [mlir] 4a411eb - [MLIR] Fix rewrite of ops with vector operands to LLVM on GPU (#127844)
llvmlistbot at llvm.org
Wed Feb 19 11:52:05 PST 2025
Author: Benoit Jacob
Date: 2025-02-19T14:52:02-05:00
New Revision: 4a411eb4ee673e2687d38fda16d6db6b907f37d2
URL: https://github.com/llvm/llvm-project/commit/4a411eb4ee673e2687d38fda16d6db6b907f37d2
DIFF: https://github.com/llvm/llvm-project/commit/4a411eb4ee673e2687d38fda16d6db6b907f37d2.diff
LOG: [MLIR] Fix rewrite of ops with vector operands to LLVM on GPU (#127844)
There was a discrepancy between the type-converter and rewrite-pattern
parts of conversion to LLVM used in various GPU targets, at least ROCDL
and NVVM:
- The TypeConverter part was handling vectors of arbitrary rank,
converting them to nests of `!llvm.array< ... >` with a vector at the
inner-most dimension:
https://github.com/llvm/llvm-project/blob/8337d01e3058e7f47675f5b2b908b4e7821895d7/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp#L629-L655
- The rewrite pattern part was not handling `llvm.array`:
https://github.com/llvm/llvm-project/blob/8337d01e3058e7f47675f5b2b908b4e7821895d7/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp#L594-L596
That led to conversion failures when lowering `math` dialect ops on
vectors of rank 2 or higher, as in the test case added in this PR.
This PR fixes this by reusing a shared utility,
`LLVM::detail::handleMultidimensionalVectors`, already used in other
conversions to LLVM:
https://github.com/llvm/llvm-project/blob/8337d01e3058e7f47675f5b2b908b4e7821895d7/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp#L80-L104
---------
Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>
Added:
Modified:
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index cfa434699cdef..c3b3a78abe7f7 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -9,6 +9,7 @@
#include "GPUOpsLowering.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -586,22 +587,15 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
return success();
}
-/// Unrolls op if it's operating on vectors.
-LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
- ConversionPatternRewriter &rewriter,
- const LLVMTypeConverter &converter) {
+/// Helper for impl::scalarizeVectorOp. Scalarizes vectors to elements.
+/// Used either directly (for ops on 1D vectors) or as the callback passed to
+/// detail::handleMultidimensionalVectors (for ops on higher-rank vectors).
+static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands,
+ Type llvm1DVectorTy,
+ ConversionPatternRewriter &rewriter,
+ const LLVMTypeConverter &converter) {
TypeRange operandTypes(operands);
- if (llvm::none_of(operandTypes, llvm::IsaPred<VectorType>)) {
- return rewriter.notifyMatchFailure(op, "expected vector operand");
- }
- if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
- return rewriter.notifyMatchFailure(op, "expected no region/successor");
- if (op->getNumResults() != 1)
- return rewriter.notifyMatchFailure(op, "expected single result");
- VectorType vectorType = dyn_cast<VectorType>(op->getResult(0).getType());
- if (!vectorType)
- return rewriter.notifyMatchFailure(op, "expected vector result");
-
+ VectorType vectorType = cast<VectorType>(llvm1DVectorTy);
Location loc = op->getLoc();
Value result = rewriter.create<LLVM::PoisonOp>(loc, vectorType);
Type indexType = converter.convertType(rewriter.getIndexType());
@@ -621,9 +615,32 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
result = rewriter.create<LLVM::InsertElementOp>(
loc, result, scalarOp->getResult(0), index);
}
+ return result;
+}
- rewriter.replaceOp(op, result);
- return success();
+/// Unrolls op to array/vector elements.
+///
+/// The LLVM type converter keeps 1-D vectors as vectors but turns higher-rank
+/// vectors into `!llvm.array` nests whose innermost element is a 1-D vector,
+/// so both encodings are accepted here:
+///  - if any converted operand is a vector, the op is scalarized directly;
+///  - if any converted operand is an `!llvm.array`, the nest is unpacked via
+///    LLVM::detail::handleMultidimensionalVectors and every 1-D slice is
+///    scalarized.
+/// Ops with regions or successors, with other than exactly one result, or
+/// whose result is not a vector are rejected with a match failure. These
+/// checks must come first: the result type is used as a VectorType below, and
+/// a mismatch would otherwise assert in `cast<>` instead of failing the match.
+LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
+                                      ConversionPatternRewriter &rewriter,
+                                      const LLVMTypeConverter &converter) {
+  // The per-element rewrite only makes sense for region-free, single-result
+  // ops producing a vector; anything else is not ours to unroll.
+  if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
+    return rewriter.notifyMatchFailure(op, "expected no region/successor");
+  if (op->getNumResults() != 1)
+    return rewriter.notifyMatchFailure(op, "expected single result");
+  auto vectorType = dyn_cast<VectorType>(op->getResult(0).getType());
+  if (!vectorType)
+    return rewriter.notifyMatchFailure(op, "expected vector result");
+
+  TypeRange operandTypes(operands);
+  // 1-D case: the converted operands are still vectors.
+  if (llvm::any_of(operandTypes, llvm::IsaPred<VectorType>)) {
+    rewriter.replaceOp(op, scalarizeVectorOpHelper(op, operands, vectorType,
+                                                   rewriter, converter));
+    return success();
+  }
+
+  // N-D case: the converted operands are `!llvm.array` nests of 1-D vectors.
+  if (llvm::any_of(operandTypes, llvm::IsaPred<LLVM::LLVMArrayType>)) {
+    return LLVM::detail::handleMultidimensionalVectors(
+        op, operands, converter,
+        [&](Type llvm1DVectorTy, ValueRange sliceOperands) -> Value {
+          return scalarizeVectorOpHelper(op, sliceOperands, llvm1DVectorTy,
+                                         rewriter, converter);
+        },
+        rewriter);
+  }
+
+  return rewriter.notifyMatchFailure(op, "no llvm.array or vector to unroll");
}
static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index e73a74845d2b6..bd2fd020f684b 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -172,13 +172,13 @@ struct GPUReturnOpLowering : public ConvertOpToLLVMPattern<gpu::ReturnOp> {
};
namespace impl {
-/// Unrolls op if it's operating on vectors.
+/// Unrolls op to array/vector elements.
+///
+/// Handles ops whose converted operands are either 1-D vectors (scalarized
+/// directly) or `!llvm.array` nests of 1-D vectors, which is how the LLVM type
+/// converter represents higher-rank vectors (these are unpacked into 1-D
+/// slices first). Returns a match failure when no operand is a vector or an
+/// `!llvm.array`.
LogicalResult scalarizeVectorOp(Operation *op, ValueRange operands,
ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &converter);
} // namespace impl
-/// Rewriting that unrolls SourceOp to scalars if it's operating on vectors.
+/// Unrolls SourceOp to array/vector elements.
template <typename SourceOp>
struct ScalarizeVectorOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
public:
@@ -191,6 +191,7 @@ struct ScalarizeVectorOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
*this->getTypeConverter());
}
};
+
} // namespace mlir
#endif // MLIR_CONVERSION_GPUCOMMON_GPUOPSLOWERING_H_
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index e4b2f01d6544a..9448304f11dbd 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -513,3 +513,54 @@ module {
"test.possible_terminator"() : () -> ()
}) : () -> ()
}
+
+// -----
+
+module @test_module {
+ // Rank-1 operand: math.sin on vector<4xf16> is unrolled directly into one
+ // extractelement / __ocml_sin_f16 call / insertelement sequence per element.
+ // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
+ // CHECK-LABEL: func @math_sin_vector_1d
+ func.func @math_sin_vector_1d(%arg : vector<4xf16>) -> vector<4xf16> {
+ // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+ // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+ // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+ // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+ // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+ %result = math.sin %arg : vector<4xf16>
+ func.return %result : vector<4xf16>
+ }
+}
+
+// -----
+
+module @test_module {
+ // Rank-2 operand: vector<2x2xf16> is converted to
+ // !llvm.array<2 x vector<2xf16>>; each row is pulled out with extractvalue,
+ // unrolled element-wise as in the 1-D case, and written back with insertvalue.
+ // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
+ // CHECK-LABEL: func @math_sin_vector_2d
+ func.func @math_sin_vector_2d(%arg : vector<2x2xf16>) -> vector<2x2xf16> {
+ // CHECK: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+ // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+ // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+ // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+ // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ %result = math.sin %arg : vector<2x2xf16>
+ func.return %result : vector<2x2xf16>
+ }
+}
More information about the Mlir-commits mailing list