[Mlir-commits] [mlir] [mlir][vector] add extractInsertFoldConstantOp fold function and apply it to extractOp and insertOp. (PR #124399)

lonely eagle llvmlistbot at llvm.org
Sat Feb 8 10:41:47 PST 2025


https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/124399

>From b6b43622bd0d721d4d3e8288368a4b6046b03931 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Sat, 25 Jan 2025 16:55:50 +0800
Subject: [PATCH 1/5] add foldConstantOp fold function and apply it to
 extractOp and insertOp.

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 44 +++++++++++++++++++
 .../VectorToLLVM/vector-to-llvm.mlir          | 36 +++++++++++++++
 .../Vector/vector-warp-distribute.mlir        |  3 +-
 3 files changed, 81 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3fbfcb4979b4951..1b58be7b358f7bf 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1977,6 +1977,46 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
   return fromElementsOp.getElements()[flatIndex];
 }
 
+// If the dynamic operands of `extractOp` or `insertOp` is result of
+// `constantOp`, then fold it.
+template <typename T>
+static void foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
+  auto staticPosition = op.getStaticPosition().vec();
+  OperandRange dynamicPosition = op.getDynamicPosition();
+
+  // If the dynamic operands is empty, it is returned directly.
+  if (!dynamicPosition.size())
+    return;
+  unsigned index = 0;
+
+  // `opChange` is a flog. If it is true, it means to update `op` in place.
+  bool opChange = false;
+  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
+    if (!ShapedType::isDynamic(staticPosition[i]))
+      continue;
+    Value position = dynamicPosition[index++];
+
+    // If it is a block parameter, proceed to the next iteration.
+    if (!position.getDefiningOp()) {
+      operands.push_back(position);
+      continue;
+    }
+
+    if (auto constantOp =
+            mlir::dyn_cast<arith::ConstantIndexOp>(position.getDefiningOp())) {
+      opChange = true;
+      staticPosition[i] = constantOp.value();
+      continue;
+    }
+    operands.push_back(position);
+  }
+
+  if (opChange) {
+    op.setStaticPosition(staticPosition);
+    op.getOperation()->setOperands(operands);
+  }
+}
+
 OpFoldResult ExtractOp::fold(FoldAdaptor) {
   // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
   // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -1999,6 +2039,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
     return val;
   if (auto val = foldScalarExtractFromFromElements(*this))
     return val;
+  SmallVector<Value> operands = {getVector()};
+  foldConstantOp(*this, operands);
   return OpFoldResult();
 }
 
@@ -3028,6 +3070,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   // (type mismatch).
   if (getNumIndices() == 0 && getSourceType() == getType())
     return getSource();
+  SmallVector<Value> operands = {getSource(), getDest()};
+  foldConstantOp(*this, operands);
   return {};
 }
 
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 29bed9aae56827e..f8f5f9039bb1469 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -4115,3 +4115,39 @@ func.func @step_scalable() -> vector<[4]xindex> {
   %0 = vector.step : vector<[4]xindex>
   return %0 : vector<[4]xindex>
 }
+
+// -----
+
+// CHECK-LABEL: @extract_arith_constnt
+func.func @extract_arith_constnt() -> i32 {
+  %v = arith.constant dense<0> : vector<32x1xi32>
+  %c_0 = arith.constant 0 : index
+  %elem = vector.extract %v[%c_0, %c_0] : i32 from vector<32x1xi32>
+  return %elem : i32
+}
+
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %{{.*}} = llvm.extractelement %[[VAL_2]]{{\[}}%[[VAL_3]] : i64] : vector<1xi32>
+
+// -----
+
+// CHECK-LABEL: @insert_arith_constnt()
+
+func.func @insert_arith_constnt() -> vector<32x1xi32> {
+  %v = arith.constant dense<0> : vector<32x1xi32>
+  %c_0 = arith.constant 0 : index
+  %c_1 = arith.constant 1 : i32
+  %v_1 = vector.insert %c_1, %v[%c_0, %c_0] : i32 into vector<32x1xi32>
+  return %v_1 : vector<32x1xi32>
+}
+    
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
+// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VAL_5:.*]] = llvm.insertelement %[[VAL_2]], %[[VAL_3]]{{\[}}%[[VAL_4]] : i64] : vector<1xi32>
+// CHECK: %{{.*}} = llvm.insertvalue %[[VAL_5]], %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index dbe0b39422369ce..38771f25934495f 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -778,12 +778,11 @@ func.func @warp_constant(%laneid: index) -> (vector<1xf32>) {
 
 // CHECK-PROP-LABEL: func.func @vector_extract_1d(
 //   CHECK-PROP-DAG:   %[[C5_I32:.*]] = arith.constant 5 : i32
-//   CHECK-PROP-DAG:   %[[C1:.*]] = arith.constant 1 : index
 //       CHECK-PROP:   %[[R:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2xf32>) {
 //       CHECK-PROP:     %[[V:.*]] = "some_def"() : () -> vector<64xf32>
 //       CHECK-PROP:     gpu.yield %[[V]] : vector<64xf32>
 //       CHECK-PROP:   }
-//       CHECK-PROP:   %[[E:.*]] = vector.extract %[[R]][%[[C1]]] : f32 from vector<2xf32>
+//       CHECK-PROP:   %[[E:.*]] = vector.extract %[[R]][1] : f32 from vector<2xf32>
 //       CHECK-PROP:   %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle  idx %[[E]], %[[C5_I32]]
 //       CHECK-PROP:   return %[[SHUFFLED]] : f32
 func.func @vector_extract_1d(%laneid: index) -> (f32) {

>From 2d778593ba4ca180e5d7847a39c25104af0a605d Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Sun, 26 Jan 2025 10:43:29 +0800
Subject: [PATCH 2/5] return LogicalResult and use matchPattern on dynamic positions.

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1b58be7b358f7bf..389ed19489f7935 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1980,13 +1980,13 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
 // If the dynamic operands of `extractOp` or `insertOp` is result of
 // `constantOp`, then fold it.
 template <typename T>
-static void foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
+static LogicalResult foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
   auto staticPosition = op.getStaticPosition().vec();
   OperandRange dynamicPosition = op.getDynamicPosition();
 
   // If the dynamic operands is empty, it is returned directly.
   if (!dynamicPosition.size())
-    return;
+    return failure();
   unsigned index = 0;
 
   // `opChange` is a flog. If it is true, it means to update `op` in place.
@@ -2002,10 +2002,10 @@ static void foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
       continue;
     }
 
-    if (auto constantOp =
-            mlir::dyn_cast<arith::ConstantIndexOp>(position.getDefiningOp())) {
+    APInt pos;
+    if (matchPattern(position, m_ConstantInt(&pos))) {
       opChange = true;
-      staticPosition[i] = constantOp.value();
+      staticPosition[i] = pos.getSExtValue();
       continue;
     }
     operands.push_back(position);
@@ -2014,7 +2014,9 @@ static void foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
   if (opChange) {
     op.setStaticPosition(staticPosition);
     op.getOperation()->setOperands(operands);
+    return success();
   }
+  return failure();
 }
 
 OpFoldResult ExtractOp::fold(FoldAdaptor) {
@@ -2040,7 +2042,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
   if (auto val = foldScalarExtractFromFromElements(*this))
     return val;
   SmallVector<Value> operands = {getVector()};
-  foldConstantOp(*this, operands);
+  if (succeeded(foldConstantOp(*this, operands)))
+    return getResult();
   return OpFoldResult();
 }
 
@@ -3071,7 +3074,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   if (getNumIndices() == 0 && getSourceType() == getType())
     return getSource();
   SmallVector<Value> operands = {getSource(), getDest()};
-  foldConstantOp(*this, operands);
+  if (succeeded(foldConstantOp(*this, operands)))
+    return getResult();
   return {};
 }
 

>From 4b77e6b6b01148539646c4d88a0a571a715704a9 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Tue, 28 Jan 2025 14:22:17 +0800
Subject: [PATCH 3/5] update function name and add canonicalize test and
 regression test.

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 38 +++++++---------
 .../VectorToLLVM/vector-to-llvm.mlir          | 36 ---------------
 mlir/test/Dialect/Vector/canonicalize.mlir    | 44 +++++++++++++++++++
 .../CPU/extract-insert-fold-constant.mlir     | 34 ++++++++++++++
 4 files changed, 94 insertions(+), 58 deletions(-)
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/extract-insert-fold-constant.mlir

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 389ed19489f7935..7660130ca7cfa7c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1979,33 +1979,27 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
 
 // If the dynamic operands of `extractOp` or `insertOp` is result of
 // `constantOp`, then fold it.
-template <typename T>
-static LogicalResult foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
+template <typename OpType, typename AdaptorType>
+static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
+                                         SmallVectorImpl<Value> &operands) {
   auto staticPosition = op.getStaticPosition().vec();
   OperandRange dynamicPosition = op.getDynamicPosition();
-
+  ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
   // If the dynamic operands is empty, it is returned directly.
   if (!dynamicPosition.size())
-    return failure();
+    return {};
   unsigned index = 0;
 
-  // `opChange` is a flog. If it is true, it means to update `op` in place.
+  // `opChange` is a flag. If it is true, it means to update `op` in place.
   bool opChange = false;
   for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
     if (!ShapedType::isDynamic(staticPosition[i]))
       continue;
+    Attribute positionAttr = dynamicPositionAttr[index];
     Value position = dynamicPosition[index++];
-
-    // If it is a block parameter, proceed to the next iteration.
-    if (!position.getDefiningOp()) {
-      operands.push_back(position);
-      continue;
-    }
-
-    APInt pos;
-    if (matchPattern(position, m_ConstantInt(&pos))) {
+    if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
+      staticPosition[i] = attr.getInt();
       opChange = true;
-      staticPosition[i] = pos.getSExtValue();
       continue;
     }
     operands.push_back(position);
@@ -2014,12 +2008,12 @@ static LogicalResult foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
   if (opChange) {
     op.setStaticPosition(staticPosition);
     op.getOperation()->setOperands(operands);
-    return success();
+    return op.getResult();
   }
-  return failure();
+  return {};
 }
 
-OpFoldResult ExtractOp::fold(FoldAdaptor) {
+OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
   // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
   // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
   // mismatch).
@@ -2042,8 +2036,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
   if (auto val = foldScalarExtractFromFromElements(*this))
     return val;
   SmallVector<Value> operands = {getVector()};
-  if (succeeded(foldConstantOp(*this, operands)))
-    return getResult();
+  if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
+    return val;
   return OpFoldResult();
 }
 
@@ -3074,8 +3068,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   if (getNumIndices() == 0 && getSourceType() == getType())
     return getSource();
   SmallVector<Value> operands = {getSource(), getDest()};
-  if (succeeded(foldConstantOp(*this, operands)))
-    return getResult();
+  if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
+    return val;
   return {};
 }
 
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index f8f5f9039bb1469..29bed9aae56827e 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -4115,39 +4115,3 @@ func.func @step_scalable() -> vector<[4]xindex> {
   %0 = vector.step : vector<[4]xindex>
   return %0 : vector<[4]xindex>
 }
-
-// -----
-
-// CHECK-LABEL: @extract_arith_constnt
-func.func @extract_arith_constnt() -> i32 {
-  %v = arith.constant dense<0> : vector<32x1xi32>
-  %c_0 = arith.constant 0 : index
-  %elem = vector.extract %v[%c_0, %c_0] : i32 from vector<32x1xi32>
-  return %elem : i32
-}
-
-// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
-// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
-// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
-// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK: %{{.*}} = llvm.extractelement %[[VAL_2]]{{\[}}%[[VAL_3]] : i64] : vector<1xi32>
-
-// -----
-
-// CHECK-LABEL: @insert_arith_constnt()
-
-func.func @insert_arith_constnt() -> vector<32x1xi32> {
-  %v = arith.constant dense<0> : vector<32x1xi32>
-  %c_0 = arith.constant 0 : index
-  %c_1 = arith.constant 1 : i32
-  %v_1 = vector.insert %c_1, %v[%c_0, %c_0] : i32 into vector<32x1xi32>
-  return %v_1 : vector<32x1xi32>
-}
-    
-// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
-// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
-// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
-// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
-// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK: %[[VAL_5:.*]] = llvm.insertelement %[[VAL_2]], %[[VAL_3]]{{\[}}%[[VAL_4]] : i64] : vector<1xi32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[VAL_5]], %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0eebb6e8d612d41..0f017ffb97223e2 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2979,3 +2979,47 @@ func.func @contiguous_scatter_step(%base: memref<?xf32>,
     memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @extract_arith_constnt
+
+func.func @extract_arith_constnt() -> i32 {
+  %c1_i32 = arith.constant 1 : i32
+  return %c1_i32 : i32
+}
+
+// CHECK: %[[VAL_0:.*]] = arith.constant 1 : i32
+// CHECK: return %[[VAL_0]] : i32
+
+// -----
+
+// CHECK-LABEL: func @insert_arith_constnt
+
+func.func @insert_arith_constnt() -> vector<4x1xi32> {
+  %v = arith.constant dense<0> : vector<4x1xi32>
+  %c_0 = arith.constant 0 : index
+  %c_1 = arith.constant 1 : i32
+  %v_1 = vector.insert %c_1, %v[%c_0, %c_0] : i32 into vector<4x1xi32>
+  return %v_1 : vector<4x1xi32>
+}
+
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<{{\[\[}}1], [0], [0], [0]]> : vector<4x1xi32>
+// CHECK: return %[[VAL_0]] : vector<4x1xi32>
+
+// -----
+
+// CHECK-LABEL: func @insert_extract_arith_constnt
+
+func.func @insert_extract_arith_constnt() -> i32 {
+  %v = arith.constant dense<0> : vector<32x1xi32>
+  %c_0 = arith.constant 0 : index
+  %c_1 = arith.constant 1 : index
+  %c_2 = arith.constant 2 : i32
+  %v_1 = vector.insert %c_2, %v[%c_1, %c_1] : i32 into vector<32x1xi32>
+  %ret = vector.extract %v_1[%c_1, %c_1] : i32 from vector<32x1xi32>
+  return %ret : i32
+}
+
+// CHECK: %[[VAL_0:.*]] = arith.constant 2 : i32
+// CHECK: return %[[VAL_0]] : i32
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/extract-insert-fold-constant.mlir b/mlir/test/Integration/Dialect/Vector/CPU/extract-insert-fold-constant.mlir
new file mode 100644
index 000000000000000..1a99626837da556
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/extract-insert-fold-constant.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s -test-lower-to-llvm  | \
+// RUN: mlir-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @entry() {
+  %v = arith.constant dense<0> : vector<2x2xi32>
+  %c_0 = arith.constant 0 : index
+  %c_1 = arith.constant 1 : index
+  %i32_0 = arith.constant 0 : i32
+  %i32_1 = arith.constant 1 : i32
+  %i32_2 = arith.constant 2 : i32
+  %i32_3 = arith.constant 3 : i32
+  %v_1 = vector.insert %i32_0, %v[%c_0, %c_0] : i32 into vector<2x2xi32>
+  %v_2 = vector.insert %i32_1, %v_1[%c_0, %c_1] : i32 into vector<2x2xi32>
+  %v_3 = vector.insert %i32_2, %v_2[%c_1, %c_0] : i32 into vector<2x2xi32>
+  %v_4 = vector.insert %i32_3, %v_3[%c_1, %c_1] : i32 into vector<2x2xi32>
+  // CHECK: ( ( 0, 1 ), ( 2, 3 ) ) 
+  vector.print %v_4 : vector<2x2xi32>
+  %v_5 = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
+  // CHECK: 0
+  %i32_4 = vector.extract %v_5[%c_0, %c_0] : i32 from vector<2x2xi32>
+  // CHECK: 1
+  %i32_5 = vector.extract %v_5[%c_0, %c_1] : i32 from vector<2x2xi32>
+  // CHECK: 2
+  %i32_6 = vector.extract %v_5[%c_1, %c_0] : i32 from vector<2x2xi32>
+  // CHECK: 3
+  %i32_7 = vector.extract %v_5[%c_1, %c_1] : i32 from vector<2x2xi32>
+  vector.print %i32_4 : i32
+  vector.print %i32_5 : i32
+  vector.print %i32_6 : i32
+  vector.print %i32_7 : i32
+  return
+}

>From a2a5c930b7dae074d4d49fbccd2a44034bcc690a Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Thu, 30 Jan 2025 10:16:05 +0800
Subject: [PATCH 4/5] update tests and helper comments.

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  7 +-
 .../VectorToLLVM/vector-to-llvm.mlir          | 82 ++++++++++---------
 mlir/test/Dialect/Vector/canonicalize.mlir    | 22 ++---
 3 files changed, 57 insertions(+), 54 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 28274bfd06a29c3..478d5737364ef71 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1986,30 +1986,47 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
   return fromElementsOp.getElements()[flatIndex];
 }
 
-// If the dynamic operands of `extractOp` or `insertOp` is result of
-// `constantOp`, then fold it.
+// If some dynamic indices of `extractOp` or `insertOp` are known constants
+// that are in bounds for their dimension, fold them into the static position
+// in place and return the op's result; otherwise return a null Value.
 template <typename OpType, typename AdaptorType>
 static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
                                          SmallVectorImpl<Value> &operands) {
-  auto staticPosition = op.getStaticPosition().vec();
+  std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
   OperandRange dynamicPosition = op.getDynamicPosition();
   ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
+  // Shape indexed by the position: the source vector of an extract, the
+  // destination vector of an insert.
+  ArrayRef<int64_t> vectorShape;
+  if constexpr (std::is_same_v<OpType, ExtractOp>)
+    vectorShape = op.getSourceVectorType().getShape();
+  else
+    vectorShape = op.getDestVectorType().getShape();
+
   // If the dynamic operands is empty, it is returned directly.
   if (!dynamicPosition.size())
     return {};
+
+  // `index` is used to iterate over the `dynamicPosition`.
   unsigned index = 0;
 
   // `opChange` is a flag. If it is true, it means to update `op` in place.
   bool opChange = false;
   for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
     if (!ShapedType::isDynamic(staticPosition[i]))
       continue;
     Attribute positionAttr = dynamicPositionAttr[index];
     Value position = dynamicPosition[index++];
     if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
-      staticPosition[i] = attr.getInt();
-      opChange = true;
-      continue;
+      int64_t value = attr.getInt();
+      // Only fold in-bounds constants: an out-of-range value would produce an
+      // op with an invalid static position, and INT64_MIN would alias the
+      // `ShapedType::kDynamic` sentinel. Such indices stay dynamic.
+      if (value >= 0 && value < vectorShape[i]) {
+        staticPosition[i] = value;
+        opChange = true;
+        continue;
+      }
     }
     operands.push_back(position);
   }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 6d9fc5c10857aed..bce5d9bb67c9c62 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1473,6 +1473,26 @@ func.func @extract_scalar_from_vec_0d_index(%arg0: vector<index>) -> index {
 //       CHECK:   %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : i64 to index
 //       CHECK:   return %[[T3]] : index
 
+
+// -----
+
+func.func @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg : vector<32x1xi32>) -> i32 {
+  %0 = arith.constant 0 : index
+  %1 = vector.extract %arg[%0, %0] : i32 from vector<32x1xi32>
+  return %1 : i32
+}
+
+// Compile-time if the indices of extractOp if constants, the constants will be collapsed,
+// the constants are folded away, hence the lowering works.
+
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const
+//  CHECK-SAME:   %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
+//       CHECK:   %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
+//       CHECK:   %[[VEC_0:.*]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<32 x vector<1xi32>>
+//       CHECK:   %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+//       CHECK:   %[[RES:.*]] = llvm.extractelement %[[VEC_0]]{{\[}}%[[C0]] : i64] : vector<1xi32>
+//       CHECK:   return %[[RES]] : i32
+
 // -----
 
 //===----------------------------------------------------------------------===//
@@ -1726,6 +1746,29 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[1
 
 // -----
 
+func.func @insert_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg : vector<4x1xi32>) -> vector<4x1xi32> {
+  %0 = arith.constant 0 : index
+  %1 = arith.constant 1 : i32
+  %res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
+  return %res : vector<4x1xi32>
+}
+
+// Compile-time if the indices of insertOp if constants, the constants will be collapsed,
+// the constants are folded away, hence the lowering works.
+
+// CHECK-LABEL: @insert_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const
+//  CHECK-SAME:   %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> {
+//       CHECK:   %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x1xi32> to !llvm.array<4 x vector<1xi32>>
+//       CHECK:   %[[C1:.*]] = arith.constant 1 : i32
+//       CHECK:   %[[VEC_0:.*]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<4 x vector<1xi32>>
+//       CHECK:   %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+//       CHECK:   %[[VEC_1:.*]] = llvm.insertelement %[[C1]], %[[VEC_0]]{{\[}}%[[C0]] : i64] : vector<1xi32>
+//       CHECK:   %[[VEC_2:.*]] = llvm.insertvalue %[[VEC_1]], %[[CAST]][0] : !llvm.array<4 x vector<1xi32>>
+//       CHECK:   %[[RES:.*]] = builtin.unrealized_conversion_cast %[[VEC_2]] : !llvm.array<4 x vector<1xi32>> to vector<4x1xi32>
+//       CHECK:   return %[[RES]] : vector<4x1xi32>
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.type_cast
 //
@@ -4125,42 +4168,3 @@ func.func @step_scalable() -> vector<[4]xindex> {
   %0 = vector.step : vector<[4]xindex>
   return %0 : vector<[4]xindex>
 }
-
-// -----
-
-// CHECK-LABEL: func @fold_extract_constant_indices
-
-func.func @fold_extract_constant_indices(%arg : vector<32x1xi32>) -> i32 {
-  %0 = arith.constant 0 : index
-  %1 = vector.extract %arg[%0, %0] : i32 from vector<32x1xi32>
-  return %1 : i32
-}
-
-// CHECK-SAME: %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
-// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
-// CHECK: %[[VEC_0:.*]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<32 x vector<1xi32>>
-// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK: %[[RES:.*]] = llvm.extractelement %[[VEC_0]]{{\[}}%[[C0]] : i64] : vector<1xi32>
-// CHECK:  return %[[RES]] : i32
-
-// -----
-
-// CHECK-LABEL: func @fold_insert_constant_indices
-
-func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi32> {
-  %0 = arith.constant 0 : index
-  %1 = arith.constant 1 : i32
-  %res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
-  return %res : vector<4x1xi32>
-}
-
-
-// CHECK-SAME: %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> {
-// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x1xi32> to !llvm.array<4 x vector<1xi32>>
-// CHECK: %[[C1:.*]] = arith.constant 1 : i32
-// CHECK: %[[VEC_0:.*]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<4 x vector<1xi32>>
-// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK: %[[VEC_1:.*]] = llvm.insertelement %[[C1]], %[[VEC_0]]{{\[}}%[[C0]] : i64] : vector<1xi32>
-// CHECK: %[[VEC_2:.*]] = llvm.insertvalue %[[VEC_1]], %[[CAST]][0] : !llvm.array<4 x vector<1xi32>>
-// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[VEC_2]] : !llvm.array<4 x vector<1xi32>> to vector<4x1xi32>
-// CHECK: return %[[RES]] : vector<4x1xi32>
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b940ff27d67d975..965777f19457977 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3045,30 +3045,26 @@ func.func @contiguous_scatter_step(%base: memref<?xf32>,
 
 // -----
 
-// CHECK-LABEL: func @fold_extract_constant_indices
-
+// CHECK-LABEL: @fold_extract_constant_indices
+//   CHECK-SAME:   %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
+//        CHECK:   %[[RES:.*]] = vector.extract %[[ARG]][0, 0] : i32 from vector<32x1xi32>
+//        CHECK:   return %[[RES]] : i32
 func.func @fold_extract_constant_indices(%arg : vector<32x1xi32>) -> i32 {
   %0 = arith.constant 0 : index
   %1 = vector.extract %arg[%0, %0] : i32 from vector<32x1xi32>
   return %1 : i32
 }
 
-// CHECK-SAME:  %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
-// CHECK: %[[RES:.*]] = vector.extract %[[ARG]][0, 0] : i32 from vector<32x1xi32>
-// CHECK: return %[[RES]] : i32
-
 // -----
 
-// CHECK-LABEL: func @fold_insert_constant_indices
-
+// CHECK-LABEL: @fold_insert_constant_indices
+//  CHECK-SAME:   %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> {
+//       CHECK:   %[[VAL:.*]] = arith.constant 1 : i32
+//       CHECK:   %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0, 0] : i32 into vector<4x1xi32>
+//       CHECK:   return %[[RES]] : vector<4x1xi32>
 func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi32> {
   %0 = arith.constant 0 : index
   %1 = arith.constant 1 : i32
   %res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
   return %res : vector<4x1xi32>
 }
-
-// CHECK-SAME: %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> {
-// CHECK: %[[VAL:.*]] = arith.constant 1 : i32
-// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0, 0] : i32 into vector<4x1xi32>
-// CHECK:  return %[[RES]] : vector<4x1xi32>

>From 7a469d47a57f3004e65d1f00068838ec89653d2b Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Wed, 5 Feb 2025 01:49:50 +0800
Subject: [PATCH 5/5] reword comments in vector-to-llvm tests.

---
 mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index bce5d9bb67c9c62..1b1d7660578bd24 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1482,8 +1482,8 @@ func.func @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg :
   return %1 : i32
 }
 
-// Compile-time if the indices of extractOp if constants, the constants will be collapsed,
-// the constants are folded away, hence the lowering works.
+// At compile time, since the indices of extractOp are constants,
+// they will be collapsed and folded away; therefore, the lowering works.
 
 // CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const
 //  CHECK-SAME:   %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
@@ -1753,8 +1753,8 @@ func.func @insert_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg :
   return %res : vector<4x1xi32>
 }
 
-// Compile-time if the indices of insertOp if constants, the constants will be collapsed,
-// the constants are folded away, hence the lowering works.
+// At compile time, since the indices of insertOp are constants,
+// they will be collapsed and folded away; therefore, the lowering works.
 
 // CHECK-LABEL: @insert_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const
 //  CHECK-SAME:   %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> {



More information about the Mlir-commits mailing list