[Mlir-commits] [mlir] [mlir][arith.constant]Fix element type of the dense attributes in target attributes to be consistent with result type in LLVM::detail::oneToOneRewrite() (PR #149787)
Mengmeng Sun
llvmlistbot at llvm.org
Mon Jul 21 02:46:31 PDT 2025
https://github.com/MengmSun created https://github.com/llvm/llvm-project/pull/149787
As I described in [[MLIR][Vector]Add constraints to vector.shape_cast(constant) -> constant](https://github.com/llvm/llvm-project/pull/147691), we hit the following case after `convert-to-llvm`:
```mlir
...
%4 = llvm.mlir.constant(dense<0.000000e+00> : vector<192xf8E4M3FN>) : vector<192xi8>
%8 = vector.shape_cast %4 : vector<192xi8> to vector<1x192xi8>
%10 = vector.extract %8[0] : vector<192xi8> from vector<1x192xi8>
...
```
Our next pass is `Canonicalizer`. After [#133988 moved the canonicalizer to a folder](https://github.com/llvm/llvm-project/pull/133988) was merged several months ago, we hit the following assertion in the `Canonicalizer` pass:
```text
mlir::DenseElementsAttr mlir::DenseElementsAttr::reshape(mlir::ShapedType): Assertion `newType.getElementType() == curType.getElementType() && "expected the same element type"' failed.
```
That's because `llvm.mlir.constant(dense<0.000000e+00> : vector<192xf8E4M3FN>) : vector<192xi8>` is lowered from `arith.constant dense<0.000000e+00> : vector<192xf8E4M3FN>`. The element type `f8E4M3FN` of the result type is converted to `i8` by the `typeConverter`, but the element type of the dense attribute is not converted. The target attributes are kept unchanged and passed to the replacement op `llvm.mlir.constant`, so the attribute type (`vector<192xf8E4M3FN>`) and the result type (`vector<192xi8>`) disagree, and `DenseElementsAttr::reshape` asserts when the constant is folded.
Previously we tried a workaround in `ShapeCastOp::fold()`, which does solve our problem. However, as @dcaballe and @banach-space pointed out, it is better to fix the problem at its root than to work around incorrect code elsewhere. The target attributes in `LLVM::detail::oneToOneRewrite()` can be the same as the source attributes, as in the current implementation, but not in all cases.
So in this PR, we fix the element type of the dense attributes in the target attributes to be consistent with the result type. With this fix, `arith.constant dense<0.000000e+00> : vector<192xf8E4M3FN>` is converted to `llvm.mlir.constant(dense<0> : vector<192xi8>) : vector<192xi8>`. This causes no accuracy loss in the dense value: as my unit test shows, the raw bits are simply reinterpreted.
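To illustrate the "just reinterprets" point, here is a minimal standalone sketch (not MLIR code) that decodes the `i8` values expected by the new unit test back into their `f8E4M3FN` values. It assumes the usual `f8E4M3FN` layout: 1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities, and NaN encoded only as `S.1111.111`. The helper names `decodeF8E4M3FN` and `checkPatchTestConstants` are hypothetical and exist only for this example.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical helper (not an MLIR API): decode one f8E4M3FN byte.
// Assumed layout: [sign:1][exponent:4, bias 7][mantissa:3].
float decodeF8E4M3FN(uint8_t bits) {
  int sign = bits >> 7;
  int exponent = (bits >> 3) & 0xF;
  int mantissa = bits & 0x7;

  // The "FN" variant has no infinities; only S.1111.111 is NaN.
  if (exponent == 0xF && mantissa == 0x7)
    return std::nanf("");

  float magnitude;
  if (exponent == 0) {
    // Subnormal: no implicit leading one, fixed exponent of 1 - bias = -6.
    magnitude = std::ldexp(mantissa / 8.0f, -6);
  } else {
    // Normal: implicit leading one, exponent minus bias.
    magnitude = std::ldexp(1.0f + mantissa / 8.0f, exponent - 7);
  }
  return sign ? -magnitude : magnitude;
}

// The i8 constants in the new CHECK lines are the raw bytes of the
// original f8E4M3FN constants, so decoding them recovers the floats.
void checkPatchTestConstants() {
  assert(decodeF8E4M3FN(0) == 0.0f);
  assert(decodeF8E4M3FN(56) == 1.0f);
  assert(decodeF8E4M3FN(64) == 2.0f);
  assert(decodeF8E4M3FN(68) == 3.0f);
  assert(decodeF8E4M3FN(72) == 4.0f);
}
```

Under these assumptions, `dense<[56, 64, 68, 72]> : vector<4xi8>` carries exactly the same bytes as `dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf8E4M3FN>`, so only the attribute's type changes, not its raw buffer.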
From d4e6a3ff33dc82e574f92b01e7613a007f63a7d9 Mon Sep 17 00:00:00 2001
From: MengmengSun <mengmengs at nvidia.com>
Date: Mon, 21 Jul 2025 02:01:35 -0700
Subject: [PATCH] Fix element type of target attributes in oneToOneRewrite when
converting to llvm
---
mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 27 ++++++++++++++++++-
.../Conversion/ArithToLLVM/arith-to-llvm.mlir | 17 +++++++++++-
2 files changed, 42 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index c5f72f7e10b8c..329703e4f054d 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -331,10 +331,35 @@ LogicalResult LLVM::detail::oneToOneRewrite(
return failure();
}
+ // If the targetAttrs contains DenseElementsAttr,
+ // and the element type of the DenseElementsAttr and result type is
+ // inconsistent after the conversion of result types, we need to convert the
+ // element type of the DenseElementsAttr to the target type by creating a new
+ // DenseElementsAttr with the converted element type, and use the new
+ // DenseElementsAttr to replace the old one in the targetAttrs
+ SmallVector<NamedAttribute> convertedAttrs;
+ for (auto attr : targetAttrs) {
+ if (auto denseAttr = dyn_cast<DenseElementsAttr>(attr.getValue())) {
+ VectorType vectorType = dyn_cast<VectorType>(denseAttr.getType());
+ if (vectorType) {
+ auto convertedElementType =
+ typeConverter.convertType(vectorType.getElementType());
+ VectorType convertedVectorType =
+ VectorType::get(vectorType.getShape(), convertedElementType,
+ vectorType.getScalableDims());
+ convertedAttrs.emplace_back(
+ attr.getName(), DenseElementsAttr::getFromRawBuffer(
+ convertedVectorType, denseAttr.getRawData()));
+ }
+ } else {
+ convertedAttrs.push_back(attr);
+ }
+ }
+
// Create the operation through state since we don't know its C++ type.
Operation *newOp =
rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
- resultTypes, targetAttrs);
+ resultTypes, convertedAttrs);
setNativeProperties(newOp, overflowFlags);
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 83bdbe1f67118..299cc32351bdb 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -428,7 +428,7 @@ func.func @fcmp(f32, f32) -> () {
// CHECK-LABEL: @index_vector
func.func @index_vector(%arg0: vector<4xindex>) {
- // CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<4xindex>) : vector<4xi64>
+ // CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<4xi64>) : vector<4xi64>
%0 = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
// CHECK: %[[V:.*]] = llvm.add %{{.*}}, %[[CST]] : vector<4xi64>
%1 = arith.addi %arg0, %0 : vector<4xindex>
@@ -437,6 +437,21 @@ func.func @index_vector(%arg0: vector<4xindex>) {
// -----
+// CHECK-LABEL: @f8E4M3FN_vector
+func.func @f8E4M3FN_vector() -> vector<4xf8E4M3FN> {
+ // CHECK: %[[CST0:.*]] = llvm.mlir.constant(dense<0> : vector<4xi8>) : vector<4xi8>
+ %0 = arith.constant dense<0.000000e+00> : vector<4xf8E4M3FN>
+ // CHECK: %[[CST1:.*]] = llvm.mlir.constant(dense<[56, 64, 68, 72]> : vector<4xi8>) : vector<4xi8>
+ %1 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf8E4M3FN>
+ // CHECK: %[[V:.*]] = llvm.mlir.constant(dense<[56, 64, 68, 72]> : vector<4xi8>) : vector<4xi8>
+ %2 = arith.addf %0, %1 : vector<4xf8E4M3FN>
+ // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[V]] : vector<4xi8> to vector<4xf8E4M3FN>
+ // CHECK-NEXT: return %[[RES]] : vector<4xf8E4M3FN>
+ func.return %2 : vector<4xf8E4M3FN>
+}
+
+// -----
+
// CHECK-LABEL: @bitcast_1d
func.func @bitcast_1d(%arg0: vector<2xf32>) {
// CHECK: llvm.bitcast %{{.*}} : vector<2xf32> to vector<2xi32>