[Mlir-commits] [mlir] [mlir][vector] Add a foldConstantOp fold function and apply it to extractOp and insertOp. (PR #124399)
lonely eagle
llvmlistbot at llvm.org
Sat Jan 25 01:00:55 PST 2025
https://github.com/linuxlonelyeagle created https://github.com/llvm/llvm-project/pull/124399
See https://github.com/llvm/llvm-project/pull/121631.
Ping: @Groverkss @ftynse. I took a cleverer approach to the implementation, which I think should be good.
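In short, the new `foldConstantOp` helper folds dynamic positions of `vector.extract` / `vector.insert` that are defined by `arith.constant` into the static position list. A minimal sketch of the effect, mirroring the tests added below (illustrative IR, not copied verbatim from the patch):

  %c0 = arith.constant 0 : index
  %e = vector.extract %v[%c0, %c0] : i32 from vector<32x1xi32>

after folding becomes

  %e = vector.extract %v[0, 0] : i32 from vector<32x1xi32>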
From 3ff7e5924887d095b52a180cb133fda446356799 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Sat, 25 Jan 2025 16:55:50 +0800
Subject: [PATCH] add foldConstantOp fold function and apply it to extractOp
and insertOp.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 45 +++++++++++++++++++
.../VectorToLLVM/vector-to-llvm.mlir | 36 +++++++++++++++
.../Vector/vector-warp-distribute.mlir | 3 +-
3 files changed, 82 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3fbfcb4979b495..5021b097fc5ef6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -1977,6 +1978,46 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
return fromElementsOp.getElements()[flatIndex];
}
+// If a dynamic position operand of `extractOp` or `insertOp` is produced by
+// an `arith.constant`, fold it into the static position.
+template <typename T>
+static void foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
+  auto staticPosition = op.getStaticPosition().vec();
+  OperandRange dynamicPosition = op.getDynamicPosition();
+
+  // If there are no dynamic position operands, there is nothing to fold.
+  if (dynamicPosition.empty())
+    return;
+  unsigned index = 0;
+
+  // `opChange` is a flag; if it is true, `op` is updated in place.
+  bool opChange = false;
+  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
+    if (!ShapedType::isDynamic(staticPosition[i]))
+      continue;
+    Value position = dynamicPosition[index++];
+
+    // Block arguments cannot be folded; keep them as dynamic operands.
+    if (!position.getDefiningOp()) {
+      operands.push_back(position);
+      continue;
+    }
+
+    if (auto constantOp =
+            mlir::dyn_cast<arith::ConstantIndexOp>(position.getDefiningOp())) {
+      opChange = true;
+      staticPosition[i] = constantOp.value();
+      continue;
+    }
+    operands.push_back(position);
+  }
+
+  if (opChange) {
+    op.setStaticPosition(staticPosition);
+    op.getOperation()->setOperands(operands);
+  }
+}
+
OpFoldResult ExtractOp::fold(FoldAdaptor) {
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -1999,6 +2040,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
return val;
if (auto val = foldScalarExtractFromFromElements(*this))
return val;
+ SmallVector<Value> operands = {getVector()};
+ foldConstantOp(*this, operands);
return OpFoldResult();
}
@@ -3028,6 +3071,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
// (type mismatch).
if (getNumIndices() == 0 && getSourceType() == getType())
return getSource();
+ SmallVector<Value> operands = {getSource(), getDest()};
+ foldConstantOp(*this, operands);
return {};
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 29bed9aae56827..f8f5f9039bb146 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -4115,3 +4115,39 @@ func.func @step_scalable() -> vector<[4]xindex> {
%0 = vector.step : vector<[4]xindex>
return %0 : vector<[4]xindex>
}
+
+// -----
+
+// CHECK-LABEL: @extract_arith_constant
+func.func @extract_arith_constant() -> i32 {
+  %v = arith.constant dense<0> : vector<32x1xi32>
+  %c_0 = arith.constant 0 : index
+  %elem = vector.extract %v[%c_0, %c_0] : i32 from vector<32x1xi32>
+  return %elem : i32
+}
+
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %{{.*}} = llvm.extractelement %[[VAL_2]]{{\[}}%[[VAL_3]] : i64] : vector<1xi32>
+
+// -----
+
+// CHECK-LABEL: @insert_arith_constant()
+
+func.func @insert_arith_constant() -> vector<32x1xi32> {
+  %v = arith.constant dense<0> : vector<32x1xi32>
+  %c_0 = arith.constant 0 : index
+  %c_1 = arith.constant 1 : i32
+  %v_1 = vector.insert %c_1, %v[%c_0, %c_0] : i32 into vector<32x1xi32>
+  return %v_1 : vector<32x1xi32>
+}
+
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
+// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VAL_5:.*]] = llvm.insertelement %[[VAL_2]], %[[VAL_3]]{{\[}}%[[VAL_4]] : i64] : vector<1xi32>
+// CHECK: %{{.*}} = llvm.insertvalue %[[VAL_5]], %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index dbe0b39422369c..38771f25934495 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -778,12 +778,11 @@ func.func @warp_constant(%laneid: index) -> (vector<1xf32>) {
// CHECK-PROP-LABEL: func.func @vector_extract_1d(
// CHECK-PROP-DAG: %[[C5_I32:.*]] = arith.constant 5 : i32
-// CHECK-PROP-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-PROP: %[[R:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2xf32>) {
// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<64xf32>
// CHECK-PROP: gpu.yield %[[V]] : vector<64xf32>
// CHECK-PROP: }
-// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][%[[C1]]] : f32 from vector<2xf32>
+// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][1] : f32 from vector<2xf32>
// CHECK-PROP: %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle idx %[[E]], %[[C5_I32]]
// CHECK-PROP: return %[[SHUFFLED]] : f32
func.func @vector_extract_1d(%laneid: index) -> (f32) {