[Mlir-commits] [mlir] af6e3c0 - [mlir][math] Fix intrinsic conversions to LLVM for 0D-vector types (#141020)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jun 2 04:27:48 PDT 2025
Author: Artem Gindinson
Date: 2025-06-02T12:27:44+01:00
New Revision: af6e3c045b07ebc7ce09318c90048f407a15b391
URL: https://github.com/llvm/llvm-project/commit/af6e3c045b07ebc7ce09318c90048f407a15b391
DIFF: https://github.com/llvm/llvm-project/commit/af6e3c045b07ebc7ce09318c90048f407a15b391.diff
LOG: [mlir][math] Fix intrinsic conversions to LLVM for 0D-vector types (#141020)
`vector<t>` types are not compatible with the LLVM type system – with
the current approach employed within `LLVMTypeConverter`, they must be
explicitly converted into `vector<1xt>` when lowering. Employ this rule
within the conversion patterns for intrinsics that are handled directly
within `MathToLLVM`: `math.ctlz`, `.cttz`, `.absi`, `.expm1`, `.log1p`,
`.rsqrt`, `.isnan`, `.isfinite`.
This change does not cover or test patterns that are based on the
`VectorConvertToLLVMPattern` template from `LLVMCommon/VectorPattern.h`.
---------
Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
Added:
Modified:
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index b42bb773f53ee..f4d69ce8235bb 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -73,6 +73,8 @@ using ATanOpLowering = ConvertFMFMathToLLVMPattern<math::AtanOp, LLVM::ATanOp>;
using ATan2OpLowering =
ConvertFMFMathToLLVMPattern<math::Atan2Op, LLVM::ATan2Op>;
// A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
+// TODO: Result and operand types match for `absi` as opposed to `ct*z`, so it
+// may be better to separate the patterns.
template <typename MathOp, typename LLVMOp>
struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
@@ -81,26 +83,29 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
LogicalResult
matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ const auto &typeConverter = *this->getTypeConverter();
auto operandType = adaptor.getOperand().getType();
-
- if (!operandType || !LLVM::isCompatibleType(operandType))
+ auto llvmOperandType = typeConverter.convertType(operandType);
+ if (!llvmOperandType)
return failure();
auto loc = op.getLoc();
auto resultType = op.getResult().getType();
+ auto llvmResultType = typeConverter.convertType(resultType);
+ if (!llvmResultType)
+ return failure();
- if (!isa<LLVM::LLVMArrayType>(operandType)) {
- rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
- false);
+ if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
+ rewriter.replaceOpWithNewOp<LLVMOp>(op, llvmResultType,
+ adaptor.getOperand(), false);
return success();
}
- auto vectorType = dyn_cast<VectorType>(resultType);
- if (!vectorType)
+ if (!isa<VectorType>(llvmResultType))
return failure();
return LLVM::detail::handleMultidimensionalVectors(
- op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
+ op.getOperation(), adaptor.getOperands(), typeConverter,
[&](Type llvm1DVectorTy, ValueRange operands) {
return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
false);
@@ -123,40 +128,42 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
LogicalResult
matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ const auto &typeConverter = *this->getTypeConverter();
auto operandType = adaptor.getOperand().getType();
-
- if (!operandType || !LLVM::isCompatibleType(operandType))
+ auto llvmOperandType = typeConverter.convertType(operandType);
+ if (!llvmOperandType)
return failure();
auto loc = op.getLoc();
auto resultType = op.getResult().getType();
- auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
+ auto floatType = cast<FloatType>(
+ typeConverter.convertType(getElementTypeOrSelf(resultType)));
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
- if (!isa<LLVM::LLVMArrayType>(operandType)) {
+ if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
LLVM::ConstantOp one;
- if (LLVM::isCompatibleVectorType(operandType)) {
+ if (LLVM::isCompatibleVectorType(llvmOperandType)) {
one = rewriter.create<LLVM::ConstantOp>(
- loc, operandType,
- SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
+ loc, llvmOperandType,
+ SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
+ floatOne));
} else {
- one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
+ one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne);
}
auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand(),
expAttrs.getAttrs());
rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
- op, operandType, ValueRange{exp, one}, subAttrs.getAttrs());
+ op, llvmOperandType, ValueRange{exp, one}, subAttrs.getAttrs());
return success();
}
- auto vectorType = dyn_cast<VectorType>(resultType);
- if (!vectorType)
+ if (!isa<VectorType>(resultType))
return rewriter.notifyMatchFailure(op, "expected vector result type");
return LLVM::detail::handleMultidimensionalVectors(
- op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
+ op.getOperation(), adaptor.getOperands(), typeConverter,
[&](Type llvm1DVectorTy, ValueRange operands) {
auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
auto splatAttr = SplatElementsAttr::get(
@@ -181,41 +188,43 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
LogicalResult
matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ const auto &typeConverter = *this->getTypeConverter();
auto operandType = adaptor.getOperand().getType();
-
- if (!operandType || !LLVM::isCompatibleType(operandType))
+ auto llvmOperandType = typeConverter.convertType(operandType);
+ if (!llvmOperandType)
return rewriter.notifyMatchFailure(op, "unsupported operand type");
auto loc = op.getLoc();
auto resultType = op.getResult().getType();
- auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
+ auto floatType = cast<FloatType>(
+ typeConverter.convertType(getElementTypeOrSelf(resultType)));
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
- if (!isa<LLVM::LLVMArrayType>(operandType)) {
+ if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
LLVM::ConstantOp one =
- LLVM::isCompatibleVectorType(operandType)
+ isa<VectorType>(llvmOperandType)
? rewriter.create<LLVM::ConstantOp>(
- loc, operandType,
- SplatElementsAttr::get(cast<ShapedType>(resultType),
+ loc, llvmOperandType,
+ SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
floatOne))
- : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
+ : rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType,
+ floatOne);
auto add = rewriter.create<LLVM::FAddOp>(
- loc, operandType, ValueRange{one, adaptor.getOperand()},
+ loc, llvmOperandType, ValueRange{one, adaptor.getOperand()},
addAttrs.getAttrs());
- rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, ValueRange{add},
- logAttrs.getAttrs());
+ rewriter.replaceOpWithNewOp<LLVM::LogOp>(
+ op, llvmOperandType, ValueRange{add}, logAttrs.getAttrs());
return success();
}
- auto vectorType = dyn_cast<VectorType>(resultType);
- if (!vectorType)
+ if (!isa<VectorType>(resultType))
return rewriter.notifyMatchFailure(op, "expected vector result type");
return LLVM::detail::handleMultidimensionalVectors(
- op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
+ op.getOperation(), adaptor.getOperands(), typeConverter,
[&](Type llvm1DVectorTy, ValueRange operands) {
auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
auto splatAttr = SplatElementsAttr::get(
@@ -241,40 +250,42 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
LogicalResult
matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ const auto &typeConverter = *this->getTypeConverter();
auto operandType = adaptor.getOperand().getType();
-
- if (!operandType || !LLVM::isCompatibleType(operandType))
+ auto llvmOperandType = typeConverter.convertType(operandType);
+ if (!llvmOperandType)
return failure();
auto loc = op.getLoc();
auto resultType = op.getResult().getType();
- auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
+ auto floatType = cast<FloatType>(
+ typeConverter.convertType(getElementTypeOrSelf(resultType)));
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
- if (!isa<LLVM::LLVMArrayType>(operandType)) {
+ if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
LLVM::ConstantOp one;
- if (LLVM::isCompatibleVectorType(operandType)) {
+ if (isa<VectorType>(llvmOperandType)) {
one = rewriter.create<LLVM::ConstantOp>(
- loc, operandType,
- SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
+ loc, llvmOperandType,
+ SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
+ floatOne));
} else {
- one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
+ one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne);
}
auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand(),
sqrtAttrs.getAttrs());
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
- op, operandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
+ op, llvmOperandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
return success();
}
- auto vectorType = dyn_cast<VectorType>(resultType);
- if (!vectorType)
+ if (!isa<VectorType>(resultType))
return failure();
return LLVM::detail::handleMultidimensionalVectors(
- op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
+ op.getOperation(), adaptor.getOperands(), typeConverter,
[&](Type llvm1DVectorTy, ValueRange operands) {
auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
auto splatAttr = SplatElementsAttr::get(
@@ -298,13 +309,15 @@ struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
LogicalResult
matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto operandType = adaptor.getOperand().getType();
-
- if (!operandType || !LLVM::isCompatibleType(operandType))
+ const auto &typeConverter = *this->getTypeConverter();
+ auto operandType =
+ typeConverter.convertType(adaptor.getOperand().getType());
+ auto resultType = typeConverter.convertType(op.getResult().getType());
+ if (!operandType || !resultType)
return failure();
rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
- op, op.getType(), adaptor.getOperand(), llvm::fcNan);
+ op, resultType, adaptor.getOperand(), llvm::fcNan);
return success();
}
};
@@ -315,13 +328,15 @@ struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
LogicalResult
matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto operandType = adaptor.getOperand().getType();
-
- if (!operandType || !LLVM::isCompatibleType(operandType))
+ const auto &typeConverter = *this->getTypeConverter();
+ auto operandType =
+ typeConverter.convertType(adaptor.getOperand().getType());
+ auto resultType = typeConverter.convertType(op.getResult().getType());
+ if (!operandType || !resultType)
return failure();
rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
- op, op.getType(), adaptor.getOperand(), llvm::fcFinite);
+ op, resultType, adaptor.getOperand(), llvm::fcFinite);
return success();
}
};
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index 537fb967ef0e1..92904082a6f46 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -19,6 +19,8 @@ func.func @ops(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64) {
// -----
+// CHECK-LABEL: func @absi(
+// CHECK-SAME: i32
func.func @absi(%arg0: i32) -> i32 {
// CHECK: = "llvm.intr.abs"(%{{.*}}) <{is_int_min_poison = false}> : (i32) -> i32
%0 = math.absi %arg0 : i32
@@ -27,6 +29,17 @@ func.func @absi(%arg0: i32) -> i32 {
// -----
+// CHECK-LABEL: func @absi_0dvector(
+// CHECK-SAME: vector<i32>
+func.func @absi_0dvector(%arg0 : vector<i32>) {
+ // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
+ // CHECK: "llvm.intr.abs"(%[[CAST]]) <{is_int_min_poison = false}> : (vector<1xi32>) -> vector<1xi32>
+ %0 = math.absi %arg0 : vector<i32>
+ func.return
+}
+
+// -----
+
// CHECK-LABEL: func @log1p(
// CHECK-SAME: f32
func.func @log1p(%arg0 : f32) {
@@ -89,6 +102,19 @@ func.func @log1p_scalable_vector(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
// -----
+// CHECK-LABEL: func @log1p_0dvector(
+// CHECK-SAME: vector<f32>
+func.func @log1p_0dvector(%arg0 : vector<f32>) {
+ // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<f32> to vector<1xf32>
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<1xf32>) : vector<1xf32>
+ // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %[[CAST]] : vector<1xf32>
+ // CHECK: %[[LOG:.*]] = llvm.intr.log(%[[ADD]]) : (vector<1xf32>) -> vector<1xf32>
+ %0 = math.log1p %arg0 : vector<f32>
+ func.return
+}
+
+// -----
+
// CHECK-LABEL: func @expm1(
// CHECK-SAME: f32
func.func @expm1(%arg0 : f32) {
@@ -149,6 +175,19 @@ func.func @expm1_vector_fmf(%arg0 : vector<4xf32>) {
// -----
+// CHECK-LABEL: func @expm1_0dvector(
+// CHECK-SAME: vector<f32>
+func.func @expm1_0dvector(%arg0 : vector<f32>) {
+ // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<f32> to vector<1xf32>
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<1xf32>) : vector<1xf32>
+ // CHECK: %[[EXP:.*]] = llvm.intr.exp(%[[CAST]]) : (vector<1xf32>) -> vector<1xf32>
+ // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : vector<1xf32>
+ %0 = math.expm1 %arg0 : vector<f32>
+ func.return
+}
+
+// -----
+
// CHECK-LABEL: func @rsqrt(
// CHECK-SAME: f32
func.func @rsqrt(%arg0 : f32) {
@@ -161,6 +200,19 @@ func.func @rsqrt(%arg0 : f32) {
// -----
+// CHECK-LABEL: func @rsqrt_0dvector(
+// CHECK-SAME: vector<f32>
+func.func @rsqrt_0dvector(%arg0 : vector<f32>) {
+ // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<f32> to vector<1xf32>
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<1xf32>) : vector<1xf32>
+ // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%[[CAST]]) : (vector<1xf32>) -> vector<1xf32>
+ // CHECK: %[[SUB:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<1xf32>
+ %0 = math.rsqrt %arg0 : vector<f32>
+ func.return
+}
+
+// -----
+
// CHECK-LABEL: func @trigonometrics
// CHECK-SAME: [[ARG0:%.+]]: f32
func.func @trigonometrics(%arg0: f32) {
@@ -279,6 +331,15 @@ func.func @ctlz(%arg0 : i32) {
func.return
}
+// CHECK-LABEL: func @ctlz_0dvector(
+// CHECK-SAME: vector<i32>
+func.func @ctlz_0dvector(%arg0 : vector<i32>) {
+ // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
+ // CHECK: "llvm.intr.ctlz"(%[[CAST]]) <{is_zero_poison = false}> : (vector<1xi32>) -> vector<1xi32>
+ %0 = math.ctlz %arg0 : vector<i32>
+ func.return
+}
+
// -----
// CHECK-LABEL: func @cttz(
@@ -291,6 +352,17 @@ func.func @cttz(%arg0 : i32) {
// -----
+// CHECK-LABEL: func @cttz_0dvector(
+// CHECK-SAME: vector<i32>
+func.func @cttz_0dvector(%arg0 : vector<i32>) {
+ // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
+ // CHECK: "llvm.intr.cttz"(%[[CAST]]) <{is_zero_poison = false}> : (vector<1xi32>) -> vector<1xi32>
+ %0 = math.cttz %arg0 : vector<i32>
+ func.return
+}
+
+// -----
+
// CHECK-LABEL: func @cttz_vec(
// CHECK-SAME: i32
func.func @cttz_vec(%arg0 : vector<4xi32>) {
@@ -351,6 +423,17 @@ func.func @isnan_double(%arg0 : f64) {
// -----
+// CHECK-LABEL: func @isnan_0dvector(
+// CHECK-SAME: vector<f32>
+func.func @isnan_0dvector(%arg0 : vector<f32>) {
+ // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<f32> to vector<1xf32>
+ // CHECK: "llvm.intr.is.fpclass"(%[[CAST]]) <{bit = 3 : i32}> : (vector<1xf32>) -> vector<1xi1>
+ %0 = math.isnan %arg0 : vector<f32>
+ func.return
+}
+
+// -----
+
// CHECK-LABEL: func @isfinite_double(
// CHECK-SAME: f64
func.func @isfinite_double(%arg0 : f64) {
@@ -361,6 +444,17 @@ func.func @isfinite_double(%arg0 : f64) {
// -----
+// CHECK-LABEL: func @isfinite_0dvector(
+// CHECK-SAME: vector<f32>
+func.func @isfinite_0dvector(%arg0 : vector<f32>) {
+ // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<f32> to vector<1xf32>
+ // CHECK: "llvm.intr.is.fpclass"(%[[CAST]]) <{bit = 504 : i32}> : (vector<1xf32>) -> vector<1xi1>
+ %0 = math.isfinite %arg0 : vector<f32>
+ func.return
+}
+
+// -----
+
// CHECK-LABEL: func @rsqrt_double(
// CHECK-SAME: f64
func.func @rsqrt_double(%arg0 : f64) {
More information about the Mlir-commits
mailing list