[Mlir-commits] [mlir] 00e3566 - [mlir][arith] Add arith.constant materialization helper

Rahul Kayaith llvmlistbot at llvm.org
Thu Apr 20 13:32:01 PDT 2023


Author: Rahul Kayaith
Date: 2023-04-20T16:31:52-04:00
New Revision: 00e3566d6c98f7c531be5140a614ca7fb3cc03a1

URL: https://github.com/llvm/llvm-project/commit/00e3566d6c98f7c531be5140a614ca7fb3cc03a1
DIFF: https://github.com/llvm/llvm-project/commit/00e3566d6c98f7c531be5140a614ca7fb3cc03a1.diff

LOG: [mlir][arith] Add arith.constant materialization helper

This adds `arith::ConstantOp::materialize`, which builds a constant from
an attribute and type only if doing so would result in a valid op, and
returns null otherwise. This is useful for dialect `materializeConstant`
hooks, and makes it possible to remove the previous `(Attribute, Type)`
builder, which was only used during materialization.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D148491

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/lib/Dialect/Math/IR/MathOps.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 35f4d0761db93..7b7b30e84ce2d 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -178,15 +178,15 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
   // splitting the Standard dialect.
   let results = (outs /*SignlessIntegerOrFloatLike*/AnyType:$result);
 
-  let builders = [
-    OpBuilder<(ins "Attribute":$value, "Type":$type),
-    [{ build($_builder, $_state, type, value); }]>,
-  ];
-
   let extraClassDeclaration = [{
     /// Whether the constant op can be constructed with a particular value and
     /// type.
     static bool isBuildableWith(Attribute value, Type type);
+
+    /// Build the constant op with `value` and `type` if possible, otherwise
+    /// returns null.
+    static ConstantOp materialize(OpBuilder &builder, Attribute value,
+                                  Type type, Location loc);
   }];
 
   let hasFolder = 1;

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 0ffb258f3f488..bd9811095356b 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -225,7 +225,7 @@ void AffineDialect::initialize() {
 Operation *AffineDialect::materializeConstant(OpBuilder &builder,
                                               Attribute value, Type type,
                                               Location loc) {
-  return builder.create<arith::ConstantOp>(loc, type, value);
+  return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
 /// A utility function to check if a value is defined at the top level of an

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
index 0a7b2c4f0b561..7f2d79355fe0f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
@@ -49,5 +49,5 @@ void arith::ArithDialect::initialize() {
 Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder,
                                                     Attribute value, Type type,
                                                     Location loc) {
-  return builder.create<arith::ConstantOp>(loc, value, type);
+  return ConstantOp::materialize(builder, value, type, loc);
 }

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index e203dbc847339..d3ca1987a1707 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -185,6 +185,13 @@ bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
 }
 
+ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
+                                          Type type, Location loc) {
+  if (isBuildableWith(value, type))
+    return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
+  return nullptr;
+}
+
 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
 
 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,

diff  --git a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
index ec56d93c6f156..0a2691a113f71 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
@@ -36,9 +36,7 @@ Operation *complex::ComplexDialect::materializeConstant(OpBuilder &builder,
     return builder.create<complex::ConstantOp>(loc, type,
                                                value.cast<ArrayAttr>());
   }
-  if (arith::ConstantOp::isBuildableWith(value, type))
-    return builder.create<arith::ConstantOp>(loc, type, value);
-  return nullptr;
+  return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
 #define GET_ATTRDEF_CLASSES

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07ddc02c00f47..f9265b43eb379 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2109,5 +2109,5 @@ void LinalgDialect::getCanonicalizationPatterns(
 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                               Attribute value, Type type,
                                               Location loc) {
-  return builder.create<arith::ConstantOp>(loc, type, value);
+  return arith::ConstantOp::materialize(builder, value, type, loc);
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f3c8c5c06a087..57e6e2a6c81e4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1688,8 +1688,8 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
       }
 
       // Create a constant scalar value from the splat constant.
-      Value scalarConstant = rewriter.create<arith::ConstantOp>(
-          def->getLoc(), constantAttr, constantAttr.getType());
+      Value scalarConstant =
+          rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr);
 
       SmallVector<Value> outputOperands = genericOp.getOutputs();
       auto fusedOp = rewriter.create<GenericOp>(

diff  --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 5c93f9f7017c9..ae9dc08c745b4 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -522,5 +522,5 @@ OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
 Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
                                                   Attribute value, Type type,
                                                   Location loc) {
-  return builder.create<arith::ConstantOp>(loc, value, type);
+  return arith::ConstantOp::materialize(builder, value, type, loc);
 }

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index ee47547a1775b..a828fb6a7a679 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -74,9 +74,7 @@ struct Wrapper {
 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                               Attribute value, Type type,
                                               Location loc) {
-  if (arith::ConstantOp::isBuildableWith(value, type))
-    return builder.create<arith::ConstantOp>(loc, value, type);
-  return nullptr;
+  return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 9af32fb5afe74..3417388d0bb92 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -154,9 +154,7 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
   if (type.isa<WitnessType>())
     return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
-  if (arith::ConstantOp::isBuildableWith(value, type))
-    return builder.create<arith::ConstantOp>(loc, type, value);
-  return nullptr;
+  return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 0eca1843ea19f..99382a375c1f4 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -38,8 +38,8 @@ using namespace mlir::tensor;
 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
                                               Attribute value, Type type,
                                               Location loc) {
-  if (arith::ConstantOp::isBuildableWith(value, type))
-    return builder.create<arith::ConstantOp>(loc, value, type);
+  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
+    return op;
   if (complex::ConstantOp::isBuildableWith(value, type))
     return builder.create<complex::ConstantOp>(loc, type,
                                                value.cast<ArrayAttr>());

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 89ca099112309..e214820c2f47f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -280,7 +280,7 @@ void VectorDialect::initialize() {
 Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                               Attribute value, Type type,
                                               Location loc) {
-  return builder.create<arith::ConstantOp>(loc, type, value);
+  return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
 IntegerType vector::getVectorSubscriptType(Builder &builder) {


        


More information about the Mlir-commits mailing list