[Mlir-commits] [mlir] d13940e - [mlir][Vector] Teach how to materialize UB constant to Vector (#125596)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 4 11:18:35 PST 2025
Author: Diego Caballero
Date: 2025-02-04T11:18:30-08:00
New Revision: d13940ee263ff50b7a71e21424913cc0266bf9d4
URL: https://github.com/llvm/llvm-project/commit/d13940ee263ff50b7a71e21424913cc0266bf9d4
DIFF: https://github.com/llvm/llvm-project/commit/d13940ee263ff50b7a71e21424913cc0266bf9d4.diff
LOG: [mlir][Vector] Teach how to materialize UB constant to Vector (#125596)
This PR adds support for UB constant materialization (i.e., generating
`ub::PoisonOp`) to `VectorDialect::materializeConstant`. The lack of this
support was the reason why the vector folders generating poison didn't work.
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 93f89eda2da5a6..2ec1b97f2f241d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -437,6 +437,9 @@ void VectorDialect::initialize() {
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
+ if (isa<ub::PoisonAttrInterface>(value))
+ return value.getDialect().materializeConstant(builder, value, type, loc);
+
return arith::ConstantOp::materialize(builder, value, type, loc);
}
@@ -2273,20 +2276,6 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
return success();
}
-/// Fold an insert or extract operation into an poison value when a poison index
-/// is found at any dimension of the static position.
-template <typename OpTy>
-LogicalResult
-canonicalizePoisonIndexInsertExtractOp(OpTy op, PatternRewriter &rewriter) {
- if (auto poisonAttr = foldPoisonIndexInsertExtractOp(
- op.getContext(), op.getStaticPosition(), OpTy::kPoisonIndex)) {
- rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getType(), poisonAttr);
- return success();
- }
-
- return failure();
-}
-
} // namespace
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -2295,7 +2284,6 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
- results.add(canonicalizePoisonIndexInsertExtractOp<ExtractOp>);
}
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
@@ -3068,7 +3056,6 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
InsertOpConstantFolder>(context);
- results.add(canonicalizePoisonIndexInsertExtractOp<InsertOp>);
}
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 7df6defc0f202f..9a6337f14ace33 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1250,13 +1250,13 @@ func.func @extract_scalar_from_vec_1d_f32(%arg0: vector<16xf32>) -> f32 {
// -----
-func.func @extract_poison_idx(%arg0: vector<16xf32>) -> f32 {
+func.func @extract_scalar_from_vec_1d_f32_poison_idx(%arg0: vector<16xf32>) -> f32 {
%0 = vector.extract %arg0[-1]: f32 from vector<16xf32>
return %0 : f32
}
-// CHECK-LABEL: @extract_poison_idx
-// CHECK: %[[IDX:.*]] = llvm.mlir.constant(-1 : i64) : i64
-// CHECK: llvm.extractelement {{.*}}[%[[IDX]] : i64] : vector<16xf32>
+// CHECK-LABEL: @extract_scalar_from_vec_1d_f32_poison_idx
+// CHECK: %[[UB:.*]] = ub.poison : f32
+// CHECK: return %[[UB]] : f32
// -----
@@ -1335,6 +1335,16 @@ func.func @extract_vec_2d_from_vec_3d_f32(%arg0: vector<4x3x16xf32>) -> vector<3
// -----
+func.func @extract_vec_2d_from_vec_3d_f32_poison_idx(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> {
+ %0 = vector.extract %arg0[-1]: vector<3x16xf32> from vector<4x3x16xf32>
+ return %0 : vector<3x16xf32>
+}
+// CHECK-LABEL: @extract_vec_2d_from_vec_3d_f32_poison_idx
+// CHECK: %[[UB:.*]] = ub.poison : vector<3x16xf32>
+// CHECK: return %[[UB]] : vector<3x16xf32>
+
+// -----
+
func.func @extract_vec_2d_from_vec_3d_f32_scalable(%arg0: vector<4x3x[16]xf32>) -> vector<3x[16]xf32> {
%0 = vector.extract %arg0[0]: vector<3x[16]xf32> from vector<4x3x[16]xf32>
return %0 : vector<3x[16]xf32>
More information about the Mlir-commits
mailing list