[Mlir-commits] [mlir] [mlir][vector]add foldConstantOp fold function and apply it to extractOp and insertOp. (PR #124399)

lonely eagle llvmlistbot at llvm.org
Sat Jan 25 01:00:55 PST 2025


https://github.com/linuxlonelyeagle created https://github.com/llvm/llvm-project/pull/124399

see https://github.com/llvm/llvm-project/pull/121631
Ping: @Groverkss @ftynse I took a cleverer approach to implementing it, which I think should work well.

>From 3ff7e5924887d095b52a180cb133fda446356799 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Sat, 25 Jan 2025 16:55:50 +0800
Subject: [PATCH] add foldConstantOp fold function and apply it to extractOp
 and insertOp.

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 45 +++++++++++++++++++
 .../VectorToLLVM/vector-to-llvm.mlir          | 36 +++++++++++++++
 .../Vector/vector-warp-distribute.mlir        |  3 +-
 3 files changed, 82 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3fbfcb4979b495..5021b097fc5ef6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -1977,6 +1978,46 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
   return fromElementsOp.getElements()[flatIndex];
 }
 
+/// Folds dynamic position operands of `op` (vector.extract/insert) that are
+/// in-bounds `arith.constant` index values into the static position, updating
+/// `op` in place. On entry `operands` holds the op's non-position operands;
+/// the remaining dynamic positions are appended. Returns the op's result if it
+/// was updated (what an in-place `fold` must return), else a null result.
+template <typename T>
+static OpFoldResult foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
+  auto staticPosition = op.getStaticPosition().vec();
+  OperandRange dynamicPosition = op.getDynamicPosition();
+  if (dynamicPosition.empty())
+    return {};
+  // Positions outside [0, dim) must stay dynamic: as static values they fail
+  // verification, and one equal to ShapedType::kDynamic would corrupt `op`.
+  VectorType indexedType;
+  if constexpr (std::is_same_v<T, ExtractOp>)
+    indexedType = op.getSourceVectorType();
+  else
+    indexedType = op.getDestVectorType();
+  unsigned index = 0; // Next unconsumed entry of `dynamicPosition`.
+  bool opChanged = false;
+  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
+    if (!ShapedType::isDynamic(staticPosition[i]))
+      continue;
+    Value position = dynamicPosition[index++];
+    auto constantOp = position.getDefiningOp<arith::ConstantIndexOp>();
+    if (constantOp && constantOp.value() >= 0 &&
+        constantOp.value() < indexedType.getDimSize(i)) {
+      staticPosition[i] = constantOp.value();
+      opChanged = true;
+      continue;
+    }
+    operands.push_back(position);
+  }
+  if (!opChanged)
+    return {};
+  op.setStaticPosition(staticPosition);
+  op.getOperation()->setOperands(operands);
+  return op.getResult();
+}
+
 OpFoldResult ExtractOp::fold(FoldAdaptor) {
   // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
   // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -1999,6 +2040,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
     return val;
   if (auto val = foldScalarExtractFromFromElements(*this))
     return val;
+  SmallVector<Value> operands = {getVector()};
+  foldConstantOp(*this, operands);
   return OpFoldResult();
 }
 
@@ -3028,6 +3071,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   // (type mismatch).
   if (getNumIndices() == 0 && getSourceType() == getType())
     return getSource();
+  SmallVector<Value> operands = {getSource(), getDest()};
+  foldConstantOp(*this, operands);
   return {};
 }
 
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 29bed9aae56827..f8f5f9039bb146 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -4115,3 +4115,39 @@ func.func @step_scalable() -> vector<[4]xindex> {
   %0 = vector.step : vector<[4]xindex>
   return %0 : vector<[4]xindex>
 }
+
+// -----
+
+// CHECK-LABEL: @extract_arith_constant
+func.func @extract_arith_constant() -> i32 {
+  %v = arith.constant dense<0> : vector<32x1xi32>
+  %c_0 = arith.constant 0 : index
+  %elem = vector.extract %v[%c_0, %c_0] : i32 from vector<32x1xi32>
+  return %elem : i32
+}
+
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %{{.*}} = llvm.extractelement %[[VAL_2]]{{\[}}%[[VAL_3]] : i64] : vector<1xi32>
+
+// -----
+
+// CHECK-LABEL: @insert_arith_constant()
+
+func.func @insert_arith_constant() -> vector<32x1xi32> {
+  %v = arith.constant dense<0> : vector<32x1xi32>
+  %c_0 = arith.constant 0 : index
+  %c_1 = arith.constant 1 : i32
+  %v_1 = vector.insert %c_1, %v[%c_0, %c_0] : i32 into vector<32x1xi32>
+  return %v_1 : vector<32x1xi32>
+}
+
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
+// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VAL_5:.*]] = llvm.insertelement %[[VAL_2]], %[[VAL_3]]{{\[}}%[[VAL_4]] : i64] : vector<1xi32>
+// CHECK: %{{.*}} = llvm.insertvalue %[[VAL_5]], %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index dbe0b39422369c..38771f25934495 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -778,12 +778,11 @@ func.func @warp_constant(%laneid: index) -> (vector<1xf32>) {
 
 // CHECK-PROP-LABEL: func.func @vector_extract_1d(
 //   CHECK-PROP-DAG:   %[[C5_I32:.*]] = arith.constant 5 : i32
-//   CHECK-PROP-DAG:   %[[C1:.*]] = arith.constant 1 : index
 //       CHECK-PROP:   %[[R:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2xf32>) {
 //       CHECK-PROP:     %[[V:.*]] = "some_def"() : () -> vector<64xf32>
 //       CHECK-PROP:     gpu.yield %[[V]] : vector<64xf32>
 //       CHECK-PROP:   }
-//       CHECK-PROP:   %[[E:.*]] = vector.extract %[[R]][%[[C1]]] : f32 from vector<2xf32>
+//       CHECK-PROP:   %[[E:.*]] = vector.extract %[[R]][1] : f32 from vector<2xf32>
 //       CHECK-PROP:   %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle  idx %[[E]], %[[C5_I32]]
 //       CHECK-PROP:   return %[[SHUFFLED]] : f32
 func.func @vector_extract_1d(%laneid: index) -> (f32) {



More information about the Mlir-commits mailing list