[Mlir-commits] [mlir] [mlir][vector] Add extractInsertFoldConstantOp fold function and apply it to extractOp and insertOp. (PR #124399)
lonely eagle
llvmlistbot at llvm.org
Sat Feb 8 10:41:47 PST 2025
https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/124399
>From b6b43622bd0d721d4d3e8288368a4b6046b03931 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Sat, 25 Jan 2025 16:55:50 +0800
Subject: [PATCH 1/5] add foldConstantOp fold function and apply it to
extractOp and insertOp.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 44 +++++++++++++++++++
.../VectorToLLVM/vector-to-llvm.mlir | 36 +++++++++++++++
.../Vector/vector-warp-distribute.mlir | 3 +-
3 files changed, 81 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3fbfcb4979b4951..1b58be7b358f7bf 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1977,6 +1977,46 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
return fromElementsOp.getElements()[flatIndex];
}
+// If the dynamic operands of `extractOp` or `insertOp` is result of
+// `constantOp`, then fold it.
+template <typename T>
+static void foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
+ auto staticPosition = op.getStaticPosition().vec();
+ OperandRange dynamicPosition = op.getDynamicPosition();
+
+ // If the dynamic operands is empty, it is returned directly.
+ if (!dynamicPosition.size())
+ return;
+ unsigned index = 0;
+
+ // `opChange` is a flog. If it is true, it means to update `op` in place.
+ bool opChange = false;
+ for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
+ if (!ShapedType::isDynamic(staticPosition[i]))
+ continue;
+ Value position = dynamicPosition[index++];
+
+ // If it is a block parameter, proceed to the next iteration.
+ if (!position.getDefiningOp()) {
+ operands.push_back(position);
+ continue;
+ }
+
+ if (auto constantOp =
+ mlir::dyn_cast<arith::ConstantIndexOp>(position.getDefiningOp())) {
+ opChange = true;
+ staticPosition[i] = constantOp.value();
+ continue;
+ }
+ operands.push_back(position);
+ }
+
+ if (opChange) {
+ op.setStaticPosition(staticPosition);
+ op.getOperation()->setOperands(operands);
+ }
+}
+
OpFoldResult ExtractOp::fold(FoldAdaptor) {
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -1999,6 +2039,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
return val;
if (auto val = foldScalarExtractFromFromElements(*this))
return val;
+ SmallVector<Value> operands = {getVector()};
+ foldConstantOp(*this, operands);
return OpFoldResult();
}
@@ -3028,6 +3070,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
// (type mismatch).
if (getNumIndices() == 0 && getSourceType() == getType())
return getSource();
+ SmallVector<Value> operands = {getSource(), getDest()};
+ foldConstantOp(*this, operands);
return {};
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 29bed9aae56827e..f8f5f9039bb1469 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -4115,3 +4115,39 @@ func.func @step_scalable() -> vector<[4]xindex> {
%0 = vector.step : vector<[4]xindex>
return %0 : vector<[4]xindex>
}
+
+// -----
+
+// CHECK-LABEL: @extract_arith_constnt
+func.func @extract_arith_constnt() -> i32 {
+ %v = arith.constant dense<0> : vector<32x1xi32>
+ %c_0 = arith.constant 0 : index
+ %elem = vector.extract %v[%c_0, %c_0] : i32 from vector<32x1xi32>
+ return %elem : i32
+}
+
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %{{.*}} = llvm.extractelement %[[VAL_2]]{{\[}}%[[VAL_3]] : i64] : vector<1xi32>
+
+// -----
+
+// CHECK-LABEL: @insert_arith_constnt()
+
+func.func @insert_arith_constnt() -> vector<32x1xi32> {
+ %v = arith.constant dense<0> : vector<32x1xi32>
+ %c_0 = arith.constant 0 : index
+ %c_1 = arith.constant 1 : i32
+ %v_1 = vector.insert %c_1, %v[%c_0, %c_0] : i32 into vector<32x1xi32>
+ return %v_1 : vector<32x1xi32>
+}
+
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
+// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VAL_5:.*]] = llvm.insertelement %[[VAL_2]], %[[VAL_3]]{{\[}}%[[VAL_4]] : i64] : vector<1xi32>
+// CHECK: %{{.*}} = llvm.insertvalue %[[VAL_5]], %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index dbe0b39422369ce..38771f25934495f 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -778,12 +778,11 @@ func.func @warp_constant(%laneid: index) -> (vector<1xf32>) {
// CHECK-PROP-LABEL: func.func @vector_extract_1d(
// CHECK-PROP-DAG: %[[C5_I32:.*]] = arith.constant 5 : i32
-// CHECK-PROP-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-PROP: %[[R:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2xf32>) {
// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<64xf32>
// CHECK-PROP: gpu.yield %[[V]] : vector<64xf32>
// CHECK-PROP: }
-// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][%[[C1]]] : f32 from vector<2xf32>
+// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][1] : f32 from vector<2xf32>
// CHECK-PROP: %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle idx %[[E]], %[[C5_I32]]
// CHECK-PROP: return %[[SHUFFLED]] : f32
func.func @vector_extract_1d(%laneid: index) -> (f32) {
>From 2d778593ba4ca180e5d7847a39c25104af0a605d Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Sun, 26 Jan 2025 10:43:29 +0800
Subject: [PATCH 2/5] return LogicalResult and use matchPattern on dynamic position.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 18 +++++++++++-------
1 file changed, 11 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1b58be7b358f7bf..389ed19489f7935 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1980,13 +1980,13 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
// If the dynamic operands of `extractOp` or `insertOp` is result of
// `constantOp`, then fold it.
template <typename T>
-static void foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
+static LogicalResult foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
auto staticPosition = op.getStaticPosition().vec();
OperandRange dynamicPosition = op.getDynamicPosition();
// If the dynamic operands is empty, it is returned directly.
if (!dynamicPosition.size())
- return;
+ return failure();
unsigned index = 0;
// `opChange` is a flog. If it is true, it means to update `op` in place.
@@ -2002,10 +2002,10 @@ static void foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
continue;
}
- if (auto constantOp =
- mlir::dyn_cast<arith::ConstantIndexOp>(position.getDefiningOp())) {
+ APInt pos;
+ if (matchPattern(position, m_ConstantInt(&pos))) {
opChange = true;
- staticPosition[i] = constantOp.value();
+ staticPosition[i] = pos.getSExtValue();
continue;
}
operands.push_back(position);
@@ -2014,7 +2014,9 @@ static void foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
if (opChange) {
op.setStaticPosition(staticPosition);
op.getOperation()->setOperands(operands);
+ return success();
}
+ return failure();
}
OpFoldResult ExtractOp::fold(FoldAdaptor) {
@@ -2040,7 +2042,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
if (auto val = foldScalarExtractFromFromElements(*this))
return val;
SmallVector<Value> operands = {getVector()};
- foldConstantOp(*this, operands);
+ if (succeeded(foldConstantOp(*this, operands)))
+ return getResult();
return OpFoldResult();
}
@@ -3071,7 +3074,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
if (getNumIndices() == 0 && getSourceType() == getType())
return getSource();
SmallVector<Value> operands = {getSource(), getDest()};
- foldConstantOp(*this, operands);
+ if (succeeded(foldConstantOp(*this, operands)))
+ return getResult();
return {};
}
>From 4b77e6b6b01148539646c4d88a0a571a715704a9 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Tue, 28 Jan 2025 14:22:17 +0800
Subject: [PATCH 3/5] update function name and add canonicalize test and
regression test.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 38 +++++++---------
.../VectorToLLVM/vector-to-llvm.mlir | 36 ---------------
mlir/test/Dialect/Vector/canonicalize.mlir | 44 +++++++++++++++++++
.../CPU/extract-insert-fold-constant.mlir | 34 ++++++++++++++
4 files changed, 94 insertions(+), 58 deletions(-)
create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/extract-insert-fold-constant.mlir
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 389ed19489f7935..7660130ca7cfa7c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1979,33 +1979,27 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
// If the dynamic operands of `extractOp` or `insertOp` is result of
// `constantOp`, then fold it.
-template <typename T>
-static LogicalResult foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
+template <typename OpType, typename AdaptorType>
+static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
+ SmallVectorImpl<Value> &operands) {
auto staticPosition = op.getStaticPosition().vec();
OperandRange dynamicPosition = op.getDynamicPosition();
-
+ ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
// If the dynamic operands is empty, it is returned directly.
if (!dynamicPosition.size())
- return failure();
+ return {};
unsigned index = 0;
- // `opChange` is a flog. If it is true, it means to update `op` in place.
+ // `opChange` is a flag. If it is true, it means to update `op` in place.
bool opChange = false;
for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
if (!ShapedType::isDynamic(staticPosition[i]))
continue;
+ Attribute positionAttr = dynamicPositionAttr[index];
Value position = dynamicPosition[index++];
-
- // If it is a block parameter, proceed to the next iteration.
- if (!position.getDefiningOp()) {
- operands.push_back(position);
- continue;
- }
-
- APInt pos;
- if (matchPattern(position, m_ConstantInt(&pos))) {
+ if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
+ staticPosition[i] = attr.getInt();
opChange = true;
- staticPosition[i] = pos.getSExtValue();
continue;
}
operands.push_back(position);
@@ -2014,12 +2008,12 @@ static LogicalResult foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
if (opChange) {
op.setStaticPosition(staticPosition);
op.getOperation()->setOperands(operands);
- return success();
+ return op.getResult();
}
- return failure();
+ return {};
}
-OpFoldResult ExtractOp::fold(FoldAdaptor) {
+OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
// mismatch).
@@ -2042,8 +2036,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
if (auto val = foldScalarExtractFromFromElements(*this))
return val;
SmallVector<Value> operands = {getVector()};
- if (succeeded(foldConstantOp(*this, operands)))
- return getResult();
+ if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
+ return val;
return OpFoldResult();
}
@@ -3074,8 +3068,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
if (getNumIndices() == 0 && getSourceType() == getType())
return getSource();
SmallVector<Value> operands = {getSource(), getDest()};
- if (succeeded(foldConstantOp(*this, operands)))
- return getResult();
+ if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
+ return val;
return {};
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index f8f5f9039bb1469..29bed9aae56827e 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -4115,39 +4115,3 @@ func.func @step_scalable() -> vector<[4]xindex> {
%0 = vector.step : vector<[4]xindex>
return %0 : vector<[4]xindex>
}
-
-// -----
-
-// CHECK-LABEL: @extract_arith_constnt
-func.func @extract_arith_constnt() -> i32 {
- %v = arith.constant dense<0> : vector<32x1xi32>
- %c_0 = arith.constant 0 : index
- %elem = vector.extract %v[%c_0, %c_0] : i32 from vector<32x1xi32>
- return %elem : i32
-}
-
-// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
-// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
-// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
-// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK: %{{.*}} = llvm.extractelement %[[VAL_2]]{{\[}}%[[VAL_3]] : i64] : vector<1xi32>
-
-// -----
-
-// CHECK-LABEL: @insert_arith_constnt()
-
-func.func @insert_arith_constnt() -> vector<32x1xi32> {
- %v = arith.constant dense<0> : vector<32x1xi32>
- %c_0 = arith.constant 0 : index
- %c_1 = arith.constant 1 : i32
- %v_1 = vector.insert %c_1, %v[%c_0, %c_0] : i32 into vector<32x1xi32>
- return %v_1 : vector<32x1xi32>
-}
-
-// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
-// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
-// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
-// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
-// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK: %[[VAL_5:.*]] = llvm.insertelement %[[VAL_2]], %[[VAL_3]]{{\[}}%[[VAL_4]] : i64] : vector<1xi32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[VAL_5]], %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0eebb6e8d612d41..0f017ffb97223e2 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2979,3 +2979,47 @@ func.func @contiguous_scatter_step(%base: memref<?xf32>,
memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
return
}
+
+// -----
+
+// CHECK-LABEL: func @extract_arith_constnt
+
+func.func @extract_arith_constnt() -> i32 {
+ %c1_i32 = arith.constant 1 : i32
+ return %c1_i32 : i32
+}
+
+// CHECK: %[[VAL_0:.*]] = arith.constant 1 : i32
+// CHECK: return %[[VAL_0]] : i32
+
+// -----
+
+// CHECK-LABEL: func @insert_arith_constnt
+
+func.func @insert_arith_constnt() -> vector<4x1xi32> {
+ %v = arith.constant dense<0> : vector<4x1xi32>
+ %c_0 = arith.constant 0 : index
+ %c_1 = arith.constant 1 : i32
+ %v_1 = vector.insert %c_1, %v[%c_0, %c_0] : i32 into vector<4x1xi32>
+ return %v_1 : vector<4x1xi32>
+}
+
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<{{\[\[}}1], [0], [0], [0]]> : vector<4x1xi32>
+// CHECK: return %[[VAL_0]] : vector<4x1xi32>
+
+// -----
+
+// CHECK-LABEL: func @insert_extract_arith_constnt
+
+func.func @insert_extract_arith_constnt() -> i32 {
+ %v = arith.constant dense<0> : vector<32x1xi32>
+ %c_0 = arith.constant 0 : index
+ %c_1 = arith.constant 1 : index
+ %c_2 = arith.constant 2 : i32
+ %v_1 = vector.insert %c_2, %v[%c_1, %c_1] : i32 into vector<32x1xi32>
+ %ret = vector.extract %v_1[%c_1, %c_1] : i32 from vector<32x1xi32>
+ return %ret : i32
+}
+
+// CHECK: %[[VAL_0:.*]] = arith.constant 2 : i32
+// CHECK: return %[[VAL_0]] : i32
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/extract-insert-fold-constant.mlir b/mlir/test/Integration/Dialect/Vector/CPU/extract-insert-fold-constant.mlir
new file mode 100644
index 000000000000000..1a99626837da556
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/extract-insert-fold-constant.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s -test-lower-to-llvm | \
+// RUN: mlir-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @entry() {
+ %v = arith.constant dense<0> : vector<2x2xi32>
+ %c_0 = arith.constant 0 : index
+ %c_1 = arith.constant 1 : index
+ %i32_0 = arith.constant 0 : i32
+ %i32_1 = arith.constant 1 : i32
+ %i32_2 = arith.constant 2 : i32
+ %i32_3 = arith.constant 3 : i32
+ %v_1 = vector.insert %i32_0, %v[%c_0, %c_0] : i32 into vector<2x2xi32>
+ %v_2 = vector.insert %i32_1, %v_1[%c_0, %c_1] : i32 into vector<2x2xi32>
+ %v_3 = vector.insert %i32_2, %v_2[%c_1, %c_0] : i32 into vector<2x2xi32>
+ %v_4 = vector.insert %i32_3, %v_3[%c_1, %c_1] : i32 into vector<2x2xi32>
+ // CHECK: ( ( 0, 1 ), ( 2, 3 ) )
+ vector.print %v_4 : vector<2x2xi32>
+ %v_5 = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
+ // CHECK: 0
+ %i32_4 = vector.extract %v_5[%c_0, %c_0] : i32 from vector<2x2xi32>
+ // CHECK: 1
+ %i32_5 = vector.extract %v_5[%c_0, %c_1] : i32 from vector<2x2xi32>
+ // CHECK: 2
+ %i32_6 = vector.extract %v_5[%c_1, %c_0] : i32 from vector<2x2xi32>
+ // CHECK: 3
+ %i32_7 = vector.extract %v_5[%c_1, %c_1] : i32 from vector<2x2xi32>
+ vector.print %i32_4 : i32
+ vector.print %i32_5 : i32
+ vector.print %i32_6 : i32
+ vector.print %i32_7 : i32
+ return
+}
>From a2a5c930b7dae074d4d49fbccd2a44034bcc690a Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Thu, 30 Jan 2025 10:16:05 +0800
Subject: [PATCH 4/5] update test.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 7 +-
.../VectorToLLVM/vector-to-llvm.mlir | 82 ++++++++++---------
mlir/test/Dialect/Vector/canonicalize.mlir | 22 ++---
3 files changed, 57 insertions(+), 54 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 28274bfd06a29c3..478d5737364ef71 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1986,17 +1986,20 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
return fromElementsOp.getElements()[flatIndex];
}
-// If the dynamic operands of `extractOp` or `insertOp` is result of
+// If the dynamic indices of `extractOp` or `insertOp` are the results of
// `constantOp`, then fold it.
template <typename OpType, typename AdaptorType>
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
SmallVectorImpl<Value> &operands) {
- auto staticPosition = op.getStaticPosition().vec();
+ std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
OperandRange dynamicPosition = op.getDynamicPosition();
ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
+
// If the dynamic operands is empty, it is returned directly.
if (!dynamicPosition.size())
return {};
+
+ // `index` is used to iterate over the `dynamicPosition`.
unsigned index = 0;
// `opChange` is a flag. If it is true, it means to update `op` in place.
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 6d9fc5c10857aed..bce5d9bb67c9c62 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1473,6 +1473,26 @@ func.func @extract_scalar_from_vec_0d_index(%arg0: vector<index>) -> index {
// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : i64 to index
// CHECK: return %[[T3]] : index
+
+// -----
+
+func.func @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg : vector<32x1xi32>) -> i32 {
+ %0 = arith.constant 0 : index
+ %1 = vector.extract %arg[%0, %0] : i32 from vector<32x1xi32>
+ return %1 : i32
+}
+
+// Compile-time if the indices of extractOp if constants, the constants will be collapsed,
+// the constants are folded away, hence the lowering works.
+
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const
+// CHECK-SAME: %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VEC_0:.*]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[RES:.*]] = llvm.extractelement %[[VEC_0]]{{\[}}%[[C0]] : i64] : vector<1xi32>
+// CHECK: return %[[RES]] : i32
+
// -----
//===----------------------------------------------------------------------===//
@@ -1726,6 +1746,29 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[1
// -----
+func.func @insert_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg : vector<4x1xi32>) -> vector<4x1xi32> {
+ %0 = arith.constant 0 : index
+ %1 = arith.constant 1 : i32
+ %res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
+ return %res : vector<4x1xi32>
+}
+
+// Compile-time if the indices of insertOp if constants, the constants will be collapsed,
+// the constants are folded away, hence the lowering works.
+
+// CHECK-LABEL: @insert_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const
+// CHECK-SAME: %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> {
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x1xi32> to !llvm.array<4 x vector<1xi32>>
+// CHECK: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK: %[[VEC_0:.*]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<4 x vector<1xi32>>
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VEC_1:.*]] = llvm.insertelement %[[C1]], %[[VEC_0]]{{\[}}%[[C0]] : i64] : vector<1xi32>
+// CHECK: %[[VEC_2:.*]] = llvm.insertvalue %[[VEC_1]], %[[CAST]][0] : !llvm.array<4 x vector<1xi32>>
+// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[VEC_2]] : !llvm.array<4 x vector<1xi32>> to vector<4x1xi32>
+// CHECK: return %[[RES]] : vector<4x1xi32>
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.type_cast
//
@@ -4125,42 +4168,3 @@ func.func @step_scalable() -> vector<[4]xindex> {
%0 = vector.step : vector<[4]xindex>
return %0 : vector<[4]xindex>
}
-
-// -----
-
-// CHECK-LABEL: func @fold_extract_constant_indices
-
-func.func @fold_extract_constant_indices(%arg : vector<32x1xi32>) -> i32 {
- %0 = arith.constant 0 : index
- %1 = vector.extract %arg[%0, %0] : i32 from vector<32x1xi32>
- return %1 : i32
-}
-
-// CHECK-SAME: %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
-// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
-// CHECK: %[[VEC_0:.*]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<32 x vector<1xi32>>
-// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK: %[[RES:.*]] = llvm.extractelement %[[VEC_0]]{{\[}}%[[C0]] : i64] : vector<1xi32>
-// CHECK: return %[[RES]] : i32
-
-// -----
-
-// CHECK-LABEL: func @fold_insert_constant_indices
-
-func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi32> {
- %0 = arith.constant 0 : index
- %1 = arith.constant 1 : i32
- %res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
- return %res : vector<4x1xi32>
-}
-
-
-// CHECK-SAME: %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> {
-// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x1xi32> to !llvm.array<4 x vector<1xi32>>
-// CHECK: %[[C1:.*]] = arith.constant 1 : i32
-// CHECK: %[[VEC_0:.*]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<4 x vector<1xi32>>
-// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK: %[[VEC_1:.*]] = llvm.insertelement %[[C1]], %[[VEC_0]]{{\[}}%[[C0]] : i64] : vector<1xi32>
-// CHECK: %[[VEC_2:.*]] = llvm.insertvalue %[[VEC_1]], %[[CAST]][0] : !llvm.array<4 x vector<1xi32>>
-// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[VEC_2]] : !llvm.array<4 x vector<1xi32>> to vector<4x1xi32>
-// CHECK: return %[[RES]] : vector<4x1xi32>
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b940ff27d67d975..965777f19457977 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3045,30 +3045,26 @@ func.func @contiguous_scatter_step(%base: memref<?xf32>,
// -----
-// CHECK-LABEL: func @fold_extract_constant_indices
-
+// CHECK-LABEL: @fold_extract_constant_indices
+// CHECK-SAME: %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
+// CHECK: %[[RES:.*]] = vector.extract %[[ARG]][0, 0] : i32 from vector<32x1xi32>
+// CHECK: return %[[RES]] : i32
func.func @fold_extract_constant_indices(%arg : vector<32x1xi32>) -> i32 {
%0 = arith.constant 0 : index
%1 = vector.extract %arg[%0, %0] : i32 from vector<32x1xi32>
return %1 : i32
}
-// CHECK-SAME: %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
-// CHECK: %[[RES:.*]] = vector.extract %[[ARG]][0, 0] : i32 from vector<32x1xi32>
-// CHECK: return %[[RES]] : i32
-
// -----
-// CHECK-LABEL: func @fold_insert_constant_indices
-
+// CHECK-LABEL: @fold_insert_constant_indices
+// CHECK-SAME: %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> {
+// CHECK: %[[VAL:.*]] = arith.constant 1 : i32
+// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0, 0] : i32 into vector<4x1xi32>
+// CHECK: return %[[RES]] : vector<4x1xi32>
func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi32> {
%0 = arith.constant 0 : index
%1 = arith.constant 1 : i32
%res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
return %res : vector<4x1xi32>
}
-
-// CHECK-SAME: %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> {
-// CHECK: %[[VAL:.*]] = arith.constant 1 : i32
-// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0, 0] : i32 into vector<4x1xi32>
-// CHECK: return %[[RES]] : vector<4x1xi32>
>From 7a469d47a57f3004e65d1f00068838ec89653d2b Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Wed, 5 Feb 2025 01:49:50 +0800
Subject: [PATCH 5/5] update comment.
---
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index bce5d9bb67c9c62..1b1d7660578bd24 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1482,8 +1482,8 @@ func.func @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg :
return %1 : i32
}
-// Compile-time if the indices of extractOp if constants, the constants will be collapsed,
-// the constants are folded away, hence the lowering works.
+// At compile time, since the indices of extractOp are constants,
+// they will be collapsed and folded away; therefore, the lowering works.
// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const
// CHECK-SAME: %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
@@ -1753,8 +1753,8 @@ func.func @insert_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg :
return %res : vector<4x1xi32>
}
-// Compile-time if the indices of insertOp if constants, the constants will be collapsed,
-// the constants are folded away, hence the lowering works.
+// At compile time, since the indices of insertOp are constants,
+// they will be collapsed and folded away; therefore, the lowering works.
// CHECK-LABEL: @insert_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const
// CHECK-SAME: %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> {
More information about the Mlir-commits
mailing list