[Mlir-commits] [mlir] 00e3566 - [mlir][arith] Add arith.constant materialization helper
Rahul Kayaith
llvmlistbot at llvm.org
Thu Apr 20 13:32:01 PDT 2023
Author: Rahul Kayaith
Date: 2023-04-20T16:31:52-04:00
New Revision: 00e3566d6c98f7c531be5140a614ca7fb3cc03a1
URL: https://github.com/llvm/llvm-project/commit/00e3566d6c98f7c531be5140a614ca7fb3cc03a1
DIFF: https://github.com/llvm/llvm-project/commit/00e3566d6c98f7c531be5140a614ca7fb3cc03a1.diff
LOG: [mlir][arith] Add arith.constant materialization helper
This adds `arith::ConstantOp::materialize`, which builds a constant from
an attribute and type only if it would result in a valid op. This is
useful for dialect `materializeConstant` hooks, and allows for removing
the previous `Attribute, Type` builder which was only used during
materialization.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D148491
Added:
Modified:
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/Math/IR/MathOps.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 35f4d0761db93..7b7b30e84ce2d 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -178,15 +178,15 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
// splitting the Standard dialect.
let results = (outs /*SignlessIntegerOrFloatLike*/AnyType:$result);
- let builders = [
- OpBuilder<(ins "Attribute":$value, "Type":$type),
- [{ build($_builder, $_state, type, value); }]>,
- ];
-
let extraClassDeclaration = [{
/// Whether the constant op can be constructed with a particular value and
/// type.
static bool isBuildableWith(Attribute value, Type type);
+
+ /// Builds the constant op with `value` and `type` if possible; otherwise
+ /// returns null.
+ static ConstantOp materialize(OpBuilder &builder, Attribute value,
+ Type type, Location loc);
}];
let hasFolder = 1;
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 0ffb258f3f488..bd9811095356b 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -225,7 +225,7 @@ void AffineDialect::initialize() {
Operation *AffineDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
- return builder.create<arith::ConstantOp>(loc, type, value);
+ return arith::ConstantOp::materialize(builder, value, type, loc);
}
/// A utility function to check if a value is defined at the top level of an
diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
index 0a7b2c4f0b561..7f2d79355fe0f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
@@ -49,5 +49,5 @@ void arith::ArithDialect::initialize() {
Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
- return builder.create<arith::ConstantOp>(loc, value, type);
+ return ConstantOp::materialize(builder, value, type, loc);
}
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index e203dbc847339..d3ca1987a1707 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -185,6 +185,13 @@ bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
}
+ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
+ Type type, Location loc) {
+ if (isBuildableWith(value, type))
+ return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
+ return nullptr;
+}
+
OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
index ec56d93c6f156..0a2691a113f71 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
@@ -36,9 +36,7 @@ Operation *complex::ComplexDialect::materializeConstant(OpBuilder &builder,
return builder.create<complex::ConstantOp>(loc, type,
value.cast<ArrayAttr>());
}
- if (arith::ConstantOp::isBuildableWith(value, type))
- return builder.create<arith::ConstantOp>(loc, type, value);
- return nullptr;
+ return arith::ConstantOp::materialize(builder, value, type, loc);
}
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07ddc02c00f47..f9265b43eb379 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2109,5 +2109,5 @@ void LinalgDialect::getCanonicalizationPatterns(
Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
- return builder.create<arith::ConstantOp>(loc, type, value);
+ return arith::ConstantOp::materialize(builder, value, type, loc);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f3c8c5c06a087..57e6e2a6c81e4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1688,8 +1688,8 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
}
// Create a constant scalar value from the splat constant.
- Value scalarConstant = rewriter.create<arith::ConstantOp>(
- def->getLoc(), constantAttr, constantAttr.getType());
+ Value scalarConstant =
+ rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr);
SmallVector<Value> outputOperands = genericOp.getOutputs();
auto fusedOp = rewriter.create<GenericOp>(
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 5c93f9f7017c9..ae9dc08c745b4 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -522,5 +522,5 @@ OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
- return builder.create<arith::ConstantOp>(loc, value, type);
+ return arith::ConstantOp::materialize(builder, value, type, loc);
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index ee47547a1775b..a828fb6a7a679 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -74,9 +74,7 @@ struct Wrapper {
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
- if (arith::ConstantOp::isBuildableWith(value, type))
- return builder.create<arith::ConstantOp>(loc, value, type);
- return nullptr;
+ return arith::ConstantOp::materialize(builder, value, type, loc);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 9af32fb5afe74..3417388d0bb92 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -154,9 +154,7 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
if (type.isa<WitnessType>())
return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
- if (arith::ConstantOp::isBuildableWith(value, type))
- return builder.create<arith::ConstantOp>(loc, type, value);
- return nullptr;
+ return arith::ConstantOp::materialize(builder, value, type, loc);
}
LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 0eca1843ea19f..99382a375c1f4 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -38,8 +38,8 @@ using namespace mlir::tensor;
Operation *TensorDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
- if (arith::ConstantOp::isBuildableWith(value, type))
- return builder.create<arith::ConstantOp>(loc, value, type);
+ if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
+ return op;
if (complex::ConstantOp::isBuildableWith(value, type))
return builder.create<complex::ConstantOp>(loc, type,
value.cast<ArrayAttr>());
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 89ca099112309..e214820c2f47f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -280,7 +280,7 @@ void VectorDialect::initialize() {
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
- return builder.create<arith::ConstantOp>(loc, type, value);
+ return arith::ConstantOp::materialize(builder, value, type, loc);
}
IntegerType vector::getVectorSubscriptType(Builder &builder) {
More information about the Mlir-commits
mailing list