[Mlir-commits] [mlir] 47fc399 - [MLIR] Extend the extractvalue fold method (#172297)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 17 06:04:06 PST 2025


Author: Vadim Curcă
Date: 2025-12-17T15:04:00+01:00
New Revision: 47fc3992ba3cd5fe44036a640cb1a5f03f5ad439

URL: https://github.com/llvm/llvm-project/commit/47fc3992ba3cd5fe44036a640cb1a5f03f5ad439
DIFF: https://github.com/llvm/llvm-project/commit/47fc3992ba3cd5fe44036a640cb1a5f03f5ad439.diff

LOG: [MLIR] Extend the extractvalue fold method (#172297)

Extend the `extractvalue` fold method to support extracting from
constant containers, such as `llvm.mlir.zero`, `llvm.mlir.undef`,
`llvm.mlir.poison`, and `llvm.mlir.constant` holding `ElementsAttr` or
`ArrayAttr`.

Added: 
    

Modified: 
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/test/Dialect/LLVMIR/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 5b819485b1be4..f9162b35966c1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1898,6 +1898,27 @@ static Type getInsertExtractValueElementType(Type llvmType,
   return llvmType;
 }
 
+/// Extracts the element at the given index from an attribute. For
+/// `ElementsAttr` and `ArrayAttr`, returns the element at the specified index.
+/// For `ZeroAttr`, `UndefAttr`, and `PoisonAttr`, returns the attribute itself
+/// unchanged. Returns `nullptr` for any other attribute, for an out-of-bounds
+/// index, or when an `ElementsAttr` cannot yield its values as `Attribute`.
+static Attribute extractElementAt(Attribute attr, size_t index) {
+  if (auto elementsAttr = dyn_cast<ElementsAttr>(attr)) {
+    if (index >= static_cast<size_t>(elementsAttr.getNumElements()))
+      return nullptr;
+    // Query the values through the fallible accessor: a fold must give up
+    // quietly on element storage that does not expose `Attribute` iteration.
+    auto values = elementsAttr.tryGetValues<Attribute>();
+    return succeeded(values) ? (*values)[index] : Attribute();
+  }
+  if (auto arrayAttr = dyn_cast<ArrayAttr>(attr))
+    return index < arrayAttr.size() ? arrayAttr[index] : Attribute();
+  if (isa<ZeroAttr, UndefAttr, PoisonAttr>(attr))
+    return attr;
+  return nullptr;
+}
+
 OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
   if (auto extractValueOp = getContainer().getDefiningOp<ExtractValueOp>()) {
     SmallVector<int64_t, 4> newPos(extractValueOp.getPosition());
@@ -1907,22 +1928,11 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
     return getResult();
   }
 
-  {
-    DenseElementsAttr constval;
-    matchPattern(getContainer(), m_Constant(&constval));
-    if (constval && constval.getElementType() == getType()) {
-      if (isa<SplatElementsAttr>(constval))
-        return constval.getSplatValue<Attribute>();
-      if (getPosition().size() == 1)
-        return constval.getValues<Attribute>()[getPosition()[0]];
-    }
-  }
-
-  auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
+  Operation *container = getContainer().getDefiningOp();
   OpFoldResult result = {};
   ArrayRef<int64_t> extractPos = getPosition();
   bool switchedToInsertedValue = false;
-  while (insertValueOp) {
+  while (auto insertValueOp = dyn_cast_if_present<InsertValueOp>(container)) {
     ArrayRef<int64_t> insertPos = insertValueOp.getPosition();
     auto extractPosSize = extractPos.size();
     auto insertPosSize = insertPos.size();
@@ -1945,7 +1955,7 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
     // In the above example, %4 is folded to %arg1.
     if (extractPosSize > insertPosSize &&
         extractPos.take_front(insertPosSize) == insertPos) {
-      insertValueOp = insertValueOp.getValue().getDefiningOp<InsertValueOp>();
+      container = insertValueOp.getValue().getDefiningOp();
       extractPos = extractPos.drop_front(insertPosSize);
       switchedToInsertedValue = true;
       continue;
@@ -1975,9 +1985,20 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
       getContainerMutable().assign(insertValueOp.getContainer());
       result = getResult();
     }
-    insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>();
+    container = insertValueOp.getContainer().getDefiningOp();
+  }
+  if (!container)
+    return result;
+
+  Attribute containerAttr;
+  if (!matchPattern(container, m_Constant(&containerAttr)))
+    return nullptr;
+  for (int64_t pos : extractPos) {
+    containerAttr = extractElementAt(containerAttr, pos);
+    if (!containerAttr)
+      return nullptr;
   }
-  return result;
+  return containerAttr;
 }
 
 LogicalResult ExtractValueOp::verify() {

diff  --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
index 755e3a3a5fa09..8303afc9eb033 100644
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -112,10 +112,10 @@ llvm.func @fold_extract_extractvalue(%arr: !llvm.struct<(i64, array<1 x ptr<1>>)
 
 // -----
 
-// CHECK-LABEL: fold_extract_const
+// CHECK-LABEL: fold_extract_const_array
 // CHECK-NOT: extractvalue
 // CHECK:  llvm.mlir.constant(5.000000e-01 : f64)
-llvm.func @fold_extract_const() -> f64 {
+llvm.func @fold_extract_const_array() -> f64 {
   %a = llvm.mlir.constant(dense<[-8.900000e+01, 5.000000e-01]> : tensor<2xf64>) : !llvm.array<2 x f64>
   %b = llvm.extractvalue %a[1] : !llvm.array<2 x f64>
   llvm.return %b : f64
@@ -123,6 +123,17 @@ llvm.func @fold_extract_const() -> f64 {
 
 // -----
 
+// CHECK-LABEL: fold_extract_const_struct
+llvm.func @fold_extract_const_struct() -> i32 {
+  // CHECK-NOT: extractvalue
+  // CHECK: llvm.mlir.constant(2 : i32)
+  %a = llvm.mlir.constant([1 : i16, 2 : i32]) : !llvm.struct<(i16, i32)>
+  %b = llvm.extractvalue %a[1] : !llvm.struct<(i16, i32)>
+  llvm.return %b : i32
+}
+
+// -----
+
 // CHECK-LABEL: fold_extract_splat
 // CHECK-NOT: extractvalue
 // CHECK: llvm.mlir.constant(-8.900000e+01 : f64)
@@ -134,6 +145,90 @@ llvm.func @fold_extract_splat() -> f64 {
 
 // -----
 
+// CHECK-LABEL: fold_extract_splat_nested
+llvm.func @fold_extract_splat_nested() -> i32 {
+  // CHECK-NOT: extractvalue
+  // CHECK: llvm.mlir.constant(1 : i32)
+  %a = llvm.mlir.constant(dense<(0, 1)> : tensor<2xcomplex<i32>>) : !llvm.array<2 x !llvm.struct<(i32, i32)>>
+  %b = llvm.extractvalue %a[1, 1] : !llvm.array<2 x !llvm.struct<(i32, i32)>>
+  llvm.return %b : i32
+}
+
+// -----
+
+// CHECK-LABEL: fold_extract_sparse
+llvm.func @fold_extract_sparse() -> f32 {
+  // CHECK-NOT: extractvalue
+  // CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0.000000e+00 : f32)
+  // CHECK-DAG: %[[C42:.*]] = llvm.mlir.constant(4.200000e+01 : f32)
+  %0 = llvm.mlir.constant(sparse<[0], [4.2e+01]> : tensor<4xf32>) : !llvm.array<4 x f32>
+  %1 = llvm.extractvalue %0[0] : !llvm.array<4 x f32>
+  %2 = llvm.extractvalue %0[1] : !llvm.array<4 x f32>
+  // CHECK: llvm.fadd %[[C42]], %[[C0]]
+  %3 = llvm.fadd %1, %2 : f32
+  llvm.return %3 : f32
+}
+
+// -----
+
+// CHECK-LABEL: fold_zero
+llvm.func @fold_zero() -> i32 {
+  // CHECK-NOT: insertvalue
+  // CHECK-NOT: extractvalue
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.zero : i32
+  %0 = llvm.mlir.zero : !llvm.struct<(i16, i32)>
+
+  %1 = llvm.mlir.undef : !llvm.array<2 x !llvm.struct<(i16, i32)>>
+  %2 = llvm.insertvalue %0, %1[0] : !llvm.array<2 x !llvm.struct<(i16, i32)>>
+  %3 = llvm.extractvalue %2[0, 1] : !llvm.array<2 x !llvm.struct<(i16, i32)>>
+  // CHECK: llvm.return %[[ZERO]]
+  llvm.return %3 : i32
+}
+
+// -----
+
+llvm.func @use_struct(!llvm.struct<(i16, i32)>)
+
+// CHECK-LABEL: fold_undef
+llvm.func @fold_undef() -> i32 {
+  // CHECK-NOT: insertvalue
+  // CHECK-NOT: extractvalue
+  // CHECK-DAG: %[[UNDEF_I32:.*]] = llvm.mlir.undef : i32
+  // CHECK-DAG: %[[UNDEF_STRUCT:.*]] = llvm.mlir.undef : !llvm.struct<(i16, i32)>
+  %0 = llvm.mlir.undef : !llvm.struct<(i8, !llvm.struct<(i16, i32)>)>
+
+  %1 = llvm.extractvalue %0[1] : !llvm.struct<(i8, !llvm.struct<(i16, i32)>)>
+  // CHECK: llvm.call @use_struct(%[[UNDEF_STRUCT]])
+  llvm.call @use_struct(%1) : (!llvm.struct<(i16, i32)>) -> ()
+
+  %2 = llvm.extractvalue %0[1, 1] : !llvm.struct<(i8, !llvm.struct<(i16, i32)>)>
+  // CHECK: llvm.return %[[UNDEF_I32]]
+  llvm.return %2 : i32
+}
+
+// -----
+
+llvm.func @use_array(!llvm.array<8 x f32>)
+
+// CHECK-LABEL: fold_poison
+llvm.func @fold_poison() -> f32 {
+  // CHECK-NOT: insertvalue
+  // CHECK-NOT: extractvalue
+  // CHECK-DAG: %[[POISON_F32:.*]] = llvm.mlir.poison : f32
+  // CHECK-DAG: %[[POISON_ARRAY:.*]] = llvm.mlir.poison : !llvm.array<8 x f32>
+  %0 = llvm.mlir.poison : !llvm.array<2 x !llvm.array<8 x f32>>
+
+  %1 = llvm.extractvalue %0[1] : !llvm.array<2 x !llvm.array<8 x f32>>
+  // CHECK: llvm.call @use_array(%[[POISON_ARRAY]])
+  llvm.call @use_array(%1) : (!llvm.array<8 x f32>) -> ()
+
+  %2 = llvm.extractvalue %0[1, 1] : !llvm.array<2 x !llvm.array<8 x f32>>
+  // CHECK: llvm.return %[[POISON_F32]]
+  llvm.return %2 : f32
+}
+
+// -----
+
 // CHECK-LABEL: fold_bitcast
 // CHECK-SAME: %[[ARG:[[:alnum:]]+]]
 // CHECK-NEXT: llvm.return %[[ARG]]


        


More information about the Mlir-commits mailing list