[Mlir-commits] [mlir] 47fc399 - [MLIR] Extend the extractvalue fold method (#172297)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Dec 17 06:04:06 PST 2025
Author: Vadim Curcă
Date: 2025-12-17T15:04:00+01:00
New Revision: 47fc3992ba3cd5fe44036a640cb1a5f03f5ad439
URL: https://github.com/llvm/llvm-project/commit/47fc3992ba3cd5fe44036a640cb1a5f03f5ad439
DIFF: https://github.com/llvm/llvm-project/commit/47fc3992ba3cd5fe44036a640cb1a5f03f5ad439.diff
LOG: [MLIR] Extend the extractvalue fold method (#172297)
Extend the `extractvalue` fold method to support extracting from
constant containers, such as `llvm.mlir.zero`, `llvm.mlir.undef`,
`llvm.mlir.poison`, and `llvm.mlir.constant` holding `ElementsAttr` or
`ArrayAttr`.
Added:
Modified:
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Dialect/LLVMIR/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 5b819485b1be4..f9162b35966c1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1898,6 +1898,27 @@ static Type getInsertExtractValueElementType(Type llvmType,
return llvmType;
}
+/// Extracts the element at the given index from an attribute. For
+/// `ElementsAttr` and `ArrayAttr`, returns the element at the specified index.
+/// For `ZeroAttr`, `UndefAttr`, and `PoisonAttr`, returns the attribute itself
+/// unchanged. Returns `nullptr` if the attribute is not one of these types or
+/// if the index is out of bounds.
+static Attribute extractElementAt(Attribute attr, size_t index) {
+ if (auto elementsAttr = dyn_cast<ElementsAttr>(attr)) {
+ if (index < static_cast<size_t>(elementsAttr.getNumElements()))
+ return elementsAttr.getValues<Attribute>()[index];
+ return nullptr;
+ }
+ if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+ if (index < arrayAttr.getValue().size())
+ return arrayAttr[index];
+ return nullptr;
+ }
+ if (isa<ZeroAttr, UndefAttr, PoisonAttr>(attr))
+ return attr;
+ return nullptr;
+}
+
OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
if (auto extractValueOp = getContainer().getDefiningOp<ExtractValueOp>()) {
SmallVector<int64_t, 4> newPos(extractValueOp.getPosition());
@@ -1907,22 +1928,11 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
return getResult();
}
- {
- DenseElementsAttr constval;
- matchPattern(getContainer(), m_Constant(&constval));
- if (constval && constval.getElementType() == getType()) {
- if (isa<SplatElementsAttr>(constval))
- return constval.getSplatValue<Attribute>();
- if (getPosition().size() == 1)
- return constval.getValues<Attribute>()[getPosition()[0]];
- }
- }
-
- auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
+ Operation *container = getContainer().getDefiningOp();
OpFoldResult result = {};
ArrayRef<int64_t> extractPos = getPosition();
bool switchedToInsertedValue = false;
- while (insertValueOp) {
+ while (auto insertValueOp = dyn_cast_if_present<InsertValueOp>(container)) {
ArrayRef<int64_t> insertPos = insertValueOp.getPosition();
auto extractPosSize = extractPos.size();
auto insertPosSize = insertPos.size();
@@ -1945,7 +1955,7 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
// In the above example, %4 is folded to %arg1.
if (extractPosSize > insertPosSize &&
extractPos.take_front(insertPosSize) == insertPos) {
- insertValueOp = insertValueOp.getValue().getDefiningOp<InsertValueOp>();
+ container = insertValueOp.getValue().getDefiningOp();
extractPos = extractPos.drop_front(insertPosSize);
switchedToInsertedValue = true;
continue;
@@ -1975,9 +1985,20 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
getContainerMutable().assign(insertValueOp.getContainer());
result = getResult();
}
- insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>();
+ container = insertValueOp.getContainer().getDefiningOp();
+ }
+ if (!container)
+ return result;
+
+ Attribute containerAttr;
+ if (!matchPattern(container, m_Constant(&containerAttr)))
+ return nullptr;
+ for (int64_t pos : extractPos) {
+ containerAttr = extractElementAt(containerAttr, pos);
+ if (!containerAttr)
+ return nullptr;
}
- return result;
+ return containerAttr;
}
LogicalResult ExtractValueOp::verify() {
diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
index 755e3a3a5fa09..8303afc9eb033 100644
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -112,10 +112,10 @@ llvm.func @fold_extract_extractvalue(%arr: !llvm.struct<(i64, array<1 x ptr<1>>)
// -----
-// CHECK-LABEL: fold_extract_const
+// CHECK-LABEL: fold_extract_const_array
// CHECK-NOT: extractvalue
// CHECK: llvm.mlir.constant(5.000000e-01 : f64)
-llvm.func @fold_extract_const() -> f64 {
+llvm.func @fold_extract_const_array() -> f64 {
%a = llvm.mlir.constant(dense<[-8.900000e+01, 5.000000e-01]> : tensor<2xf64>) : !llvm.array<2 x f64>
%b = llvm.extractvalue %a[1] : !llvm.array<2 x f64>
llvm.return %b : f64
@@ -123,6 +123,17 @@ llvm.func @fold_extract_const() -> f64 {
// -----
+// CHECK-LABEL: fold_extract_const_struct
+llvm.func @fold_extract_const_struct() -> i32 {
+ // CHECK-NOT: extractvalue
+ // CHECK: llvm.mlir.constant(2 : i32)
+ %a = llvm.mlir.constant([1 : i16, 2 : i32]) : !llvm.struct<(i16, i32)>
+ %b = llvm.extractvalue %a[1] : !llvm.struct<(i16, i32)>
+ llvm.return %b : i32
+}
+
+// -----
+
// CHECK-LABEL: fold_extract_splat
// CHECK-NOT: extractvalue
// CHECK: llvm.mlir.constant(-8.900000e+01 : f64)
@@ -134,6 +145,90 @@ llvm.func @fold_extract_splat() -> f64 {
// -----
+// CHECK-LABEL: fold_extract_splat_nested
+llvm.func @fold_extract_splat_nested() -> i32 {
+ // CHECK-NOT: extractvalue
+ // CHECK: llvm.mlir.constant(1 : i32)
+ %a = llvm.mlir.constant(dense<(0, 1)> : tensor<2xcomplex<i32>>) : !llvm.array<2 x !llvm.struct<(i32, i32)>>
+ %b = llvm.extractvalue %a[1, 1] : !llvm.array<2 x !llvm.struct<(i32, i32)>>
+ llvm.return %b : i32
+}
+
+// -----
+
+// CHECK-LABEL: fold_extract_sparse
+llvm.func @fold_extract_sparse() -> f32 {
+ // CHECK-NOT: extractvalue
+ // CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0.000000e+00 : f32)
+ // CHECK-DAG: %[[C42:.*]] = llvm.mlir.constant(4.200000e+01 : f32)
+ %0 = llvm.mlir.constant(sparse<[0], [4.2e+01]> : tensor<4xf32>) : !llvm.array<4 x f32>
+ %1 = llvm.extractvalue %0[0] : !llvm.array<4 x f32>
+ %2 = llvm.extractvalue %0[1] : !llvm.array<4 x f32>
+ // CHECK: llvm.fadd %[[C42]], %[[C0]]
+ %3 = llvm.fadd %1, %2 : f32
+ llvm.return %3 : f32
+}
+
+// -----
+
+// CHECK-LABEL: fold_zero
+llvm.func @fold_zero() -> i32 {
+ // CHECK-NOT: insertvalue
+ // CHECK-NOT: extractvalue
+ // CHECK: %[[ZERO:.*]] = llvm.mlir.zero : i32
+ %0 = llvm.mlir.zero : !llvm.struct<(i16, i32)>
+
+ %1 = llvm.mlir.undef : !llvm.array<2 x !llvm.struct<(i16, i32)>>
+ %2 = llvm.insertvalue %0, %1[0] : !llvm.array<2 x !llvm.struct<(i16, i32)>>
+ %3 = llvm.extractvalue %2[0, 1] : !llvm.array<2 x !llvm.struct<(i16, i32)>>
+ // CHECK: llvm.return %[[ZERO]]
+ llvm.return %3 : i32
+}
+
+// -----
+
+llvm.func @use_struct(!llvm.struct<(i16, i32)>)
+
+// CHECK-LABEL: fold_undef
+llvm.func @fold_undef() -> i32 {
+ // CHECK-NOT: insertvalue
+ // CHECK-NOT: extractvalue
+ // CHECK-DAG: %[[UNDEF_I32:.*]] = llvm.mlir.undef : i32
+ // CHECK-DAG: %[[UNDEF_STRUCT:.*]] = llvm.mlir.undef : !llvm.struct<(i16, i32)>
+ %0 = llvm.mlir.undef : !llvm.struct<(i8, !llvm.struct<(i16, i32)>)>
+
+ %1 = llvm.extractvalue %0[1] : !llvm.struct<(i8, !llvm.struct<(i16, i32)>)>
+ // CHECK: llvm.call @use_struct(%[[UNDEF_STRUCT]])
+ llvm.call @use_struct(%1) : (!llvm.struct<(i16, i32)>) -> ()
+
+ %2 = llvm.extractvalue %0[1, 1] : !llvm.struct<(i8, !llvm.struct<(i16, i32)>)>
+ // CHECK: llvm.return %[[UNDEF_I32]]
+ llvm.return %2 : i32
+}
+
+// -----
+
+llvm.func @use_array(!llvm.array<8 x f32>)
+
+// CHECK-LABEL: fold_poison
+llvm.func @fold_poison() -> f32 {
+ // CHECK-NOT: insertvalue
+ // CHECK-NOT: extractvalue
+ // CHECK-DAG: %[[POISON_F32:.*]] = llvm.mlir.poison : f32
+ // CHECK-DAG: %[[POISON_ARRAY:.*]] = llvm.mlir.poison : !llvm.array<8 x f32>
+ %0 = llvm.mlir.poison : !llvm.array<2 x !llvm.array<8 x f32>>
+
+ %1 = llvm.extractvalue %0[1] : !llvm.array<2 x !llvm.array<8 x f32>>
+ // CHECK: llvm.call @use_array(%[[POISON_ARRAY]])
+ llvm.call @use_array(%1) : (!llvm.array<8 x f32>) -> ()
+
+ %2 = llvm.extractvalue %0[1, 1] : !llvm.array<2 x !llvm.array<8 x f32>>
+ // CHECK: llvm.return %[[POISON_F32]]
+ llvm.return %2 : f32
+}
+
+// -----
+
// CHECK-LABEL: fold_bitcast
// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
// CHECK-NEXT: llvm.return %[[ARG]]
More information about the Mlir-commits
mailing list