[Mlir-commits] [mlir] 7ae78a6 - [mlir][vector]add extractInsertFoldConstantOp fold function and apply it to extractOp and insertOp. (#124399)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 10 08:22:03 PST 2025
Author: lonely eagle
Date: 2025-02-11T00:21:59+08:00
New Revision: 7ae78a6cdb6ce9ad1534ed10519649fb3d47aca9
URL: https://github.com/llvm/llvm-project/commit/7ae78a6cdb6ce9ad1534ed10519649fb3d47aca9
DIFF: https://github.com/llvm/llvm-project/commit/7ae78a6cdb6ce9ad1534ed10519649fb3d47aca9.diff
LOG: [mlir][vector]add extractInsertFoldConstantOp fold function and apply it to extractOp and insertOp. (#124399)
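Add the `extractInsertFoldConstantOp` fold helper, which turns dynamic position operands that are compile-time constants into static positions, and call it from `ExtractOp::fold` and `InsertOp::fold`.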
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
mlir/test/Dialect/Vector/canonicalize.mlir
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b4a5461f4405dcf..94f9ead9e16653a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1989,6 +1989,45 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
return fromElementsOp.getElements()[flatIndex];
}
+/// Returns the shape of the vector that `op` indexes into: the source vector
+/// of a `vector.extract`, or the destination vector of a `vector.insert`.
+static ArrayRef<int64_t> getIndexedVectorShape(ExtractOp op) {
+  return op.getSourceVectorType().getShape();
+}
+static ArrayRef<int64_t> getIndexedVectorShape(vector::InsertOp op) {
+  return op.getDestVectorType().getShape();
+}
+
+/// If some dynamic indices of `extractOp` or `insertOp` are in fact constants,
+/// fold them into the static position and update the op in place.
+///
+/// \param op       The extract/insert op to update.
+/// \param adaptor  Fold adaptor of `op`; its dynamic-position entries hold an
+///                 `IntegerAttr` for every index operand known to be constant.
+/// \param operands In/out. On entry it holds the non-index operands of `op`;
+///                 every index operand that stays dynamic is appended to it.
+///                 When a fold happens it becomes the new operand list of
+///                 `op`.
+/// \returns The result of `op` if at least one index was folded (signalling an
+///          in-place update), a null `Value` otherwise.
+///
+/// A constant is only folded when it is a valid static position: either within
+/// the bounds of the indexed dimension or -1 (the poison-index marker).
+/// Folding any other constant would produce an op with an invalid static
+/// position, so such indices are kept as dynamic operands.
+template <typename OpType, typename AdaptorType>
+static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
+                                         SmallVectorImpl<Value> &operands) {
+  std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
+  OperandRange dynamicPosition = op.getDynamicPosition();
+  ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
+  ArrayRef<int64_t> vectorShape = getIndexedVectorShape(op);
+
+  // Nothing to fold when the op has no dynamic position operands.
+  if (!dynamicPosition.size())
+    return {};
+
+  // `index` walks `dynamicPosition` / `dynamicPositionAttr` in step with the
+  // dynamic entries of `staticPosition`.
+  unsigned index = 0;
+
+  // `opChange` is set once at least one dynamic index has been made static,
+  // i.e. `op` has to be updated in place.
+  bool opChange = false;
+  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
+    if (!ShapedType::isDynamic(staticPosition[i]))
+      continue;
+    Attribute positionAttr = dynamicPositionAttr[index];
+    Value position = dynamicPosition[index++];
+    if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
+      int64_t value = attr.getInt();
+      // Do not fold out-of-bounds constants (-1 denotes a poison index rather
+      // than an out-of-bounds one and is therefore accepted).
+      if (value >= -1 && value < vectorShape[i]) {
+        staticPosition[i] = value;
+        opChange = true;
+        continue;
+      }
+    }
+    operands.push_back(position);
+  }
+
+  if (opChange) {
+    op.setStaticPosition(staticPosition);
+    op.getOperation()->setOperands(operands);
+    return op.getResult();
+  }
+  return {};
+}
+
/// Fold an insert or extract operation into a poison value when a poison index
/// is found at any dimension of the static position.
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
@@ -2035,6 +2074,9 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
return val;
if (auto val = foldScalarExtractFromFromElements(*this))
return val;
+ SmallVector<Value> operands = {getVector()};
+ if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
+ return val;
return OpFoldResult();
}
@@ -3094,6 +3136,9 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
// (type mismatch).
if (getNumIndices() == 0 && getSourceType() == getType())
return getSource();
+ SmallVector<Value> operands = {getSource(), getDest()};
+ if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
+ return val;
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index d319b9043b4b8b6..d261327ec005f62 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -530,6 +530,25 @@ func.func @extract_scalar_from_vec_0d_index(%arg0: vector<index>) -> index {
// -----
+func.func @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg : vector<32x1xi32>) -> i32 {
+ %0 = arith.constant 0 : index
+ %1 = vector.extract %arg[%0, %0] : i32 from vector<32x1xi32>
+ return %1 : i32
+}
+
+// Both indices of the vector.extract above are SSA values defined by
+// arith.constant rather than static positions. Because they are compile-time
+// constants they are expected to be folded into static positions, so the
+// lowering succeeds and the lines below match a constant-index lowering:
+// llvm.extractvalue at [0] followed by llvm.extractelement with a constant 0.
+// NOTE(review): the function name says "f32" but the element type is i32.
+
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const
+// CHECK-SAME: %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VEC_0:.*]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[RES:.*]] = llvm.extractelement %[[VEC_0]]{{\[}}%[[C0]] : i64] : vector<1xi32>
+// CHECK: return %[[RES]] : i32
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.insertelement
//===----------------------------------------------------------------------===//
@@ -781,6 +800,29 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[1
// -----
+func.func @insert_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg : vector<4x1xi32>) -> vector<4x1xi32> {
+ %0 = arith.constant 0 : index
+ %1 = arith.constant 1 : i32
+ %res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
+ return %res : vector<4x1xi32>
+}
+
+// Both indices of the vector.insert above are SSA values defined by
+// arith.constant rather than static positions. Because they are compile-time
+// constants they are expected to be folded into static positions, so the
+// lowering succeeds and the lines below match a constant-index lowering:
+// llvm.extractvalue at [0], llvm.insertelement with a constant 0, then
+// llvm.insertvalue back at [0].
+// NOTE(review): the function name says "from_vec" and "f32", but the op
+// inserts into a vector whose element type is i32.
+
+// CHECK-LABEL: @insert_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const
+// CHECK-SAME: %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> {
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x1xi32> to !llvm.array<4 x vector<1xi32>>
+// CHECK: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK: %[[VEC_0:.*]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<4 x vector<1xi32>>
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VEC_1:.*]] = llvm.insertelement %[[C1]], %[[VEC_0]]{{\[}}%[[C0]] : i64] : vector<1xi32>
+// CHECK: %[[VEC_2:.*]] = llvm.insertvalue %[[VEC_1]], %[[CAST]][0] : !llvm.array<4 x vector<1xi32>>
+// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[VEC_2]] : !llvm.array<4 x vector<1xi32>> to vector<4x1xi32>
+// CHECK: return %[[RES]] : vector<4x1xi32>
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.type_cast
//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a74e562ad2f68d7..93581cbfbe5e4bf 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3171,3 +3171,29 @@ func.func @contiguous_scatter_step(%base: memref<?xf32>,
memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
return
}
+
+// -----
+
+// The vector.extract below takes both indices from an arith.constant SSA
+// value. The expected output is a vector.extract with the static position
+// [0, 0] and no index operands.
+// CHECK-LABEL: @fold_extract_constant_indices
+// CHECK-SAME: %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
+// CHECK: %[[RES:.*]] = vector.extract %[[ARG]][0, 0] : i32 from vector<32x1xi32>
+// CHECK: return %[[RES]] : i32
+func.func @fold_extract_constant_indices(%arg : vector<32x1xi32>) -> i32 {
+ %0 = arith.constant 0 : index
+ %1 = vector.extract %arg[%0, %0] : i32 from vector<32x1xi32>
+ return %1 : i32
+}
+
+// -----
+
+// The vector.insert below takes both indices from an arith.constant SSA
+// value. The expected output is a vector.insert with the static position
+// [0, 0]; the inserted scalar constant stays as an operand.
+// CHECK-LABEL: @fold_insert_constant_indices
+// CHECK-SAME: %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> {
+// CHECK: %[[VAL:.*]] = arith.constant 1 : i32
+// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0, 0] : i32 into vector<4x1xi32>
+// CHECK: return %[[RES]] : vector<4x1xi32>
+func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi32> {
+ %0 = arith.constant 0 : index
+ %1 = arith.constant 1 : i32
+ %res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
+ return %res : vector<4x1xi32>
+}
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index dbe0b39422369ce..38771f25934495f 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -778,12 +778,11 @@ func.func @warp_constant(%laneid: index) -> (vector<1xf32>) {
// CHECK-PROP-LABEL: func.func @vector_extract_1d(
// CHECK-PROP-DAG: %[[C5_I32:.*]] = arith.constant 5 : i32
-// CHECK-PROP-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-PROP: %[[R:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2xf32>) {
// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<64xf32>
// CHECK-PROP: gpu.yield %[[V]] : vector<64xf32>
// CHECK-PROP: }
-// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][%[[C1]]] : f32 from vector<2xf32>
+// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][1] : f32 from vector<2xf32>
// CHECK-PROP: %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle idx %[[E]], %[[C5_I32]]
// CHECK-PROP: return %[[SHUFFLED]] : f32
func.func @vector_extract_1d(%laneid: index) -> (f32) {
More information about the Mlir-commits
mailing list