[Mlir-commits] [mlir] [mlir][arith.constant] Fix element type of the dense attributes in target attributes to be consistent with result type in LLVM::detail::oneToOneRewrite() (PR #149787)
Mengmeng Sun
llvmlistbot at llvm.org
Tue Jul 29 23:56:31 PDT 2025
https://github.com/MengmSun updated https://github.com/llvm/llvm-project/pull/149787
>From d4e6a3ff33dc82e574f92b01e7613a007f63a7d9 Mon Sep 17 00:00:00 2001
From: MengmengSun <mengmengs at nvidia.com>
Date: Mon, 21 Jul 2025 02:01:35 -0700
Subject: [PATCH 1/2] Fix element type of target attributes in oneToOneRewrite
when converting to llvm
---
mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 27 ++++++++++++++++++-
.../Conversion/ArithToLLVM/arith-to-llvm.mlir | 17 +++++++++++-
2 files changed, 42 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index c5f72f7e10b8c..329703e4f054d 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -331,10 +331,35 @@ LogicalResult LLVM::detail::oneToOneRewrite(
return failure();
}
+ // If the targetAttrs contains DenseElementsAttr,
+ // and the element type of the DenseElementsAttr and result type is
+ // inconsistent after the conversion of result types, we need to convert the
+ // element type of the DenseElementsAttr to the target type by creating a new
+ // DenseElementsAttr with the converted element type, and use the new
+ // DenseElementsAttr to replace the old one in the targetAttrs
+ SmallVector<NamedAttribute> convertedAttrs;
+ for (auto attr : targetAttrs) {
+ if (auto denseAttr = dyn_cast<DenseElementsAttr>(attr.getValue())) {
+ VectorType vectorType = dyn_cast<VectorType>(denseAttr.getType());
+ if (vectorType) {
+ auto convertedElementType =
+ typeConverter.convertType(vectorType.getElementType());
+ VectorType convertedVectorType =
+ VectorType::get(vectorType.getShape(), convertedElementType,
+ vectorType.getScalableDims());
+ convertedAttrs.emplace_back(
+ attr.getName(), DenseElementsAttr::getFromRawBuffer(
+ convertedVectorType, denseAttr.getRawData()));
+ }
+ } else {
+ convertedAttrs.push_back(attr);
+ }
+ }
+
// Create the operation through state since we don't know its C++ type.
Operation *newOp =
rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
- resultTypes, targetAttrs);
+ resultTypes, convertedAttrs);
setNativeProperties(newOp, overflowFlags);
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 83bdbe1f67118..299cc32351bdb 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -428,7 +428,7 @@ func.func @fcmp(f32, f32) -> () {
// CHECK-LABEL: @index_vector
func.func @index_vector(%arg0: vector<4xindex>) {
- // CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<4xindex>) : vector<4xi64>
+ // CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<4xi64>) : vector<4xi64>
%0 = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
// CHECK: %[[V:.*]] = llvm.add %{{.*}}, %[[CST]] : vector<4xi64>
%1 = arith.addi %arg0, %0 : vector<4xindex>
@@ -437,6 +437,21 @@ func.func @index_vector(%arg0: vector<4xindex>) {
// -----
+// CHECK-LABEL: @f8E4M3FN_vector
+func.func @f8E4M3FN_vector() -> vector<4xf8E4M3FN> {
+ // CHECK: %[[CST0:.*]] = llvm.mlir.constant(dense<0> : vector<4xi8>) : vector<4xi8>
+ %0 = arith.constant dense<0.000000e+00> : vector<4xf8E4M3FN>
+ // CHECK: %[[CST1:.*]] = llvm.mlir.constant(dense<[56, 64, 68, 72]> : vector<4xi8>) : vector<4xi8>
+ %1 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf8E4M3FN>
+ // CHECK: %[[V:.*]] = llvm.mlir.constant(dense<[56, 64, 68, 72]> : vector<4xi8>) : vector<4xi8>
+ %2 = arith.addf %0, %1 : vector<4xf8E4M3FN>
+ // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[V]] : vector<4xi8> to vector<4xf8E4M3FN>
+ // CHECK-NEXT: return %[[RES]] : vector<4xf8E4M3FN>
+ func.return %2 : vector<4xf8E4M3FN>
+}
+
+// -----
+
// CHECK-LABEL: @bitcast_1d
func.func @bitcast_1d(%arg0: vector<2xf32>) {
// CHECK: llvm.bitcast %{{.*}} : vector<2xf32> to vector<2xi32>
>From 38b4227db3dc450caf396c7429c67b677112d3df Mon Sep 17 00:00:00 2001
From: MengmengSun <mengmengs at nvidia.com>
Date: Tue, 29 Jul 2025 23:53:18 -0700
Subject: [PATCH 2/2] Extend element type conversion to all supported typed attributes
---
mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 72 +++++++++++++++++-----
1 file changed, 56 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 8546989fe8e2e..2c02db4b0da16 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -330,25 +330,65 @@ LogicalResult LLVM::detail::oneToOneRewrite(
return failure();
}
- // If the targetAttrs contains DenseElementsAttr,
- // and the element type of the DenseElementsAttr and result type is
- // inconsistent after the conversion of result types, we need to convert the
- // element type of the DenseElementsAttr to the target type by creating a new
- // DenseElementsAttr with the converted element type, and use the new
- // DenseElementsAttr to replace the old one in the targetAttrs
+ // Convert attribute element types to match the converted result types.
+ // This ensures that attributes like
+ // dense<0.0> : vector<4xf8E4M3FN> become
+ // dense<0> : vector<4xi8>
+ // when the result type is converted to i8.
SmallVector<NamedAttribute> convertedAttrs;
for (auto attr : targetAttrs) {
- if (auto denseAttr = dyn_cast<DenseElementsAttr>(attr.getValue())) {
- VectorType vectorType = dyn_cast<VectorType>(denseAttr.getType());
- if (vectorType) {
- auto convertedElementType =
- typeConverter.convertType(vectorType.getElementType());
- VectorType convertedVectorType =
- VectorType::get(vectorType.getShape(), convertedElementType,
- vectorType.getScalableDims());
+ if (auto floatAttr = dyn_cast<FloatAttr>(attr.getValue())) {
+ auto convertedElementType =
+ typeConverter.convertType(floatAttr.getType());
+ if (convertedElementType != floatAttr.getType()) {
+ // Currently, only 1-byte or sub-byte float types will be converted and
+ // converted to integer types.
+ convertedAttrs.emplace_back(
+ attr.getName(),
+ IntegerAttr::get(convertedElementType,
+ floatAttr.getValue().bitcastToAPInt()));
+ } else {
+ convertedAttrs.emplace_back(attr);
+ }
+ } else if (auto intAttr = dyn_cast<IntegerAttr>(attr.getValue())) {
+ auto convertedElementType = typeConverter.convertType(intAttr.getType());
+ if (convertedElementType != intAttr.getType()) {
convertedAttrs.emplace_back(
- attr.getName(), DenseElementsAttr::getFromRawBuffer(
- convertedVectorType, denseAttr.getRawData()));
+ attr.getName(),
+ IntegerAttr::get(convertedElementType, intAttr.getValue()));
+ } else {
+ convertedAttrs.emplace_back(attr);
+ }
+ } else if (auto denseAttr = dyn_cast<DenseElementsAttr>(attr.getValue())) {
+ if (auto shapedType = dyn_cast<ShapedType>(denseAttr.getType())) {
+ auto convertedElementType =
+ typeConverter.convertType(shapedType.getElementType());
+ if (convertedElementType != shapedType.getElementType()) {
+ ShapedType convertedShapedType =
+ shapedType.cloneWith(std::nullopt, convertedElementType);
+ convertedAttrs.emplace_back(
+ attr.getName(), DenseElementsAttr::getFromRawBuffer(
+ convertedShapedType, denseAttr.getRawData()));
+ } else {
+ convertedAttrs.emplace_back(attr);
+ }
+ }
+ } else if (auto sparseAttr =
+ dyn_cast<SparseElementsAttr>(attr.getValue())) {
+ if (auto shapedType = dyn_cast<ShapedType>(sparseAttr.getType())) {
+ auto convertedElementType =
+ typeConverter.convertType(shapedType.getElementType());
+ if (convertedElementType != shapedType.getElementType()) {
+ ShapedType convertedShapedType =
+ shapedType.cloneWith(std::nullopt, convertedElementType);
+ convertedAttrs.emplace_back(
+ attr.getName(),
+ SparseElementsAttr::get(
+ convertedShapedType, sparseAttr.getIndices(),
+ sparseAttr.getValues().bitcast(convertedElementType)));
+ } else {
+ convertedAttrs.emplace_back(attr);
+ }
}
} else {
convertedAttrs.push_back(attr);
More information about the Mlir-commits
mailing list