[Mlir-commits] [mlir] [mlir][arith.constant] Fix element type of the dense attributes in target attributes to be consistent with the result type in LLVM::detail::oneToOneRewrite() (PR #149787)

Mengmeng Sun llvmlistbot at llvm.org
Tue Jul 29 23:56:31 PDT 2025


https://github.com/MengmSun updated https://github.com/llvm/llvm-project/pull/149787

>From d4e6a3ff33dc82e574f92b01e7613a007f63a7d9 Mon Sep 17 00:00:00 2001
From: MengmengSun <mengmengs at nvidia.com>
Date: Mon, 21 Jul 2025 02:01:35 -0700
Subject: [PATCH 1/2] Fix element type of target attributes in oneToOneRewrite
 when converting to LLVM

---
 mlir/lib/Conversion/LLVMCommon/Pattern.cpp    | 27 ++++++++++++++++++-
 .../Conversion/ArithToLLVM/arith-to-llvm.mlir | 17 +++++++++++-
 2 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index c5f72f7e10b8c..329703e4f054d 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -331,10 +331,35 @@ LogicalResult LLVM::detail::oneToOneRewrite(
       return failure();
   }
 
+  // If the targetAttrs contains DenseElementsAttr,
+  // and the element type of the DenseElementsAttr and result type is
+  // inconsistent after the conversion of result types, we need to convert the
+  // element type of the DenseElementsAttr to the target type by creating a new
+  // DenseElementsAttr with the converted element type, and use the new
+  // DenseElementsAttr to replace the old one in the targetAttrs
+  SmallVector<NamedAttribute> convertedAttrs;
+  for (auto attr : targetAttrs) {
+    if (auto denseAttr = dyn_cast<DenseElementsAttr>(attr.getValue())) {
+      VectorType vectorType = dyn_cast<VectorType>(denseAttr.getType());
+      if (vectorType) {
+        auto convertedElementType =
+            typeConverter.convertType(vectorType.getElementType());
+        VectorType convertedVectorType =
+            VectorType::get(vectorType.getShape(), convertedElementType,
+                            vectorType.getScalableDims());
+        convertedAttrs.emplace_back(
+            attr.getName(), DenseElementsAttr::getFromRawBuffer(
+                                convertedVectorType, denseAttr.getRawData()));
+      }
+    } else {
+      convertedAttrs.push_back(attr);
+    }
+  }
+
   // Create the operation through state since we don't know its C++ type.
   Operation *newOp =
       rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
-                      resultTypes, targetAttrs);
+                      resultTypes, convertedAttrs);
 
   setNativeProperties(newOp, overflowFlags);
 
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 83bdbe1f67118..299cc32351bdb 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -428,7 +428,7 @@ func.func @fcmp(f32, f32) -> () {
 
 // CHECK-LABEL: @index_vector
 func.func @index_vector(%arg0: vector<4xindex>) {
-  // CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<4xindex>) : vector<4xi64>
+  // CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<4xi64>) : vector<4xi64>
   %0 = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
   // CHECK: %[[V:.*]] = llvm.add %{{.*}}, %[[CST]] : vector<4xi64>
   %1 = arith.addi %arg0, %0 : vector<4xindex>
@@ -437,6 +437,21 @@ func.func @index_vector(%arg0: vector<4xindex>) {
 
 // -----
 
+// CHECK-LABEL: @f8E4M3FN_vector
+func.func @f8E4M3FN_vector() -> vector<4xf8E4M3FN> {
+  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(dense<0> : vector<4xi8>) : vector<4xi8>
+  %0 = arith.constant dense<0.000000e+00> : vector<4xf8E4M3FN>
+  // CHECK: %[[CST1:.*]] = llvm.mlir.constant(dense<[56, 64, 68, 72]> : vector<4xi8>) : vector<4xi8>
+  %1 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf8E4M3FN>
+  // CHECK: %[[V:.*]] = llvm.mlir.constant(dense<[56, 64, 68, 72]> : vector<4xi8>) : vector<4xi8>
+  %2 = arith.addf %0, %1 : vector<4xf8E4M3FN>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[V]] : vector<4xi8> to vector<4xf8E4M3FN>
+  // CHECK-NEXT: return %[[RES]] : vector<4xf8E4M3FN>
+  func.return %2 : vector<4xf8E4M3FN>
+}
+
+// -----
+
 // CHECK-LABEL: @bitcast_1d
 func.func @bitcast_1d(%arg0: vector<2xf32>) {
   // CHECK: llvm.bitcast %{{.*}} : vector<2xf32> to vector<2xi32>

>From 38b4227db3dc450caf396c7429c67b677112d3df Mon Sep 17 00:00:00 2001
From: MengmengSun <mengmengs at nvidia.com>
Date: Tue, 29 Jul 2025 23:53:18 -0700
Subject: [PATCH 2/2] Extend element type conversion to all supported typed attributes

---
 mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 72 +++++++++++++++++-----
 1 file changed, 56 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 8546989fe8e2e..2c02db4b0da16 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -330,25 +330,65 @@ LogicalResult LLVM::detail::oneToOneRewrite(
       return failure();
   }
 
-  // If the targetAttrs contains DenseElementsAttr,
-  // and the element type of the DenseElementsAttr and result type is
-  // inconsistent after the conversion of result types, we need to convert the
-  // element type of the DenseElementsAttr to the target type by creating a new
-  // DenseElementsAttr with the converted element type, and use the new
-  // DenseElementsAttr to replace the old one in the targetAttrs
+  // Convert attribute element types to match the converted result types.
+  // This ensures that attributes like
+  // dense<0.0> : vector<4xf8E4M3FN> become
+  // dense<0> : vector<4xi8>
+  // when the result type is converted to i8.
   SmallVector<NamedAttribute> convertedAttrs;
   for (auto attr : targetAttrs) {
-    if (auto denseAttr = dyn_cast<DenseElementsAttr>(attr.getValue())) {
-      VectorType vectorType = dyn_cast<VectorType>(denseAttr.getType());
-      if (vectorType) {
-        auto convertedElementType =
-            typeConverter.convertType(vectorType.getElementType());
-        VectorType convertedVectorType =
-            VectorType::get(vectorType.getShape(), convertedElementType,
-                            vectorType.getScalableDims());
+    if (auto floatAttr = dyn_cast<FloatAttr>(attr.getValue())) {
+      auto convertedElementType =
+          typeConverter.convertType(floatAttr.getType());
+      if (convertedElementType != floatAttr.getType()) {
+        // Currently, only 1-byte and sub-byte float types are converted, and
+        // they are always converted to integer types.
+        convertedAttrs.emplace_back(
+            attr.getName(),
+            IntegerAttr::get(convertedElementType,
+                             floatAttr.getValue().bitcastToAPInt()));
+      } else {
+        convertedAttrs.emplace_back(attr);
+      }
+    } else if (auto intAttr = dyn_cast<IntegerAttr>(attr.getValue())) {
+      auto convertedElementType = typeConverter.convertType(intAttr.getType());
+      if (convertedElementType != intAttr.getType()) {
         convertedAttrs.emplace_back(
-            attr.getName(), DenseElementsAttr::getFromRawBuffer(
-                                convertedVectorType, denseAttr.getRawData()));
+            attr.getName(),
+            IntegerAttr::get(convertedElementType, intAttr.getValue()));
+      } else {
+        convertedAttrs.emplace_back(attr);
+      }
+    } else if (auto denseAttr = dyn_cast<DenseElementsAttr>(attr.getValue())) {
+      if (auto shapedType = dyn_cast<ShapedType>(denseAttr.getType())) {
+        auto convertedElementType =
+            typeConverter.convertType(shapedType.getElementType());
+        if (convertedElementType != shapedType.getElementType()) {
+          ShapedType convertedShapedType =
+              shapedType.cloneWith(std::nullopt, convertedElementType);
+          convertedAttrs.emplace_back(
+              attr.getName(), DenseElementsAttr::getFromRawBuffer(
+                                  convertedShapedType, denseAttr.getRawData()));
+        } else {
+          convertedAttrs.emplace_back(attr);
+        }
+      }
+    } else if (auto sparseAttr =
+                   dyn_cast<SparseElementsAttr>(attr.getValue())) {
+      if (auto shapedType = dyn_cast<ShapedType>(sparseAttr.getType())) {
+        auto convertedElementType =
+            typeConverter.convertType(shapedType.getElementType());
+        if (convertedElementType != shapedType.getElementType()) {
+          ShapedType convertedShapedType =
+              shapedType.cloneWith(std::nullopt, convertedElementType);
+          convertedAttrs.emplace_back(
+              attr.getName(),
+              SparseElementsAttr::get(
+                  convertedShapedType, sparseAttr.getIndices(),
+                  sparseAttr.getValues().bitcast(convertedElementType)));
+        } else {
+          convertedAttrs.emplace_back(attr);
+        }
       }
     } else {
       convertedAttrs.push_back(attr);



More information about the Mlir-commits mailing list