[Mlir-commits] [mlir] 8e123ca - [mlir:Standard] Remove support for creating a `unit` ConstantOp
River Riddle
llvmlistbot at llvm.org
Wed Feb 2 14:46:11 PST 2022
Author: River Riddle
Date: 2022-02-02T14:45:12-08:00
New Revision: 8e123ca65f5f9286e59f2c79184d01673c87aa42
URL: https://github.com/llvm/llvm-project/commit/8e123ca65f5f9286e59f2c79184d01673c87aa42
DIFF: https://github.com/llvm/llvm-project/commit/8e123ca65f5f9286e59f2c79184d01673c87aa42.diff
LOG: [mlir:Standard] Remove support for creating a `unit` ConstantOp
This is completely unused upstream, and does not really have well-defined semantics
for what it is supposed to do or how it fits into the ecosystem. Given that, as part of
splitting up the standard dialect, it's best to just remove this behavior instead of trying
to awkwardly fit it somewhere upstream. Downstream users are encouraged to
define their own operations that can clearly define the semantics of this.
This also uncovered several lingering uses of ConstantOp that weren't
updated to use arith::ConstantOp and only worked during conversions because
the constant was removed or converted into something else before
verification.
See https://llvm.discourse.group/t/standard-dialect-the-final-chapter/ for more discussion.
Differential Revision: https://reviews.llvm.org/D118654
Added:
Modified:
flang/lib/Optimizer/Builder/Character.cpp
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
mlir/lib/Target/Cpp/TranslateToCpp.cpp
mlir/test/Dialect/Standard/invalid.mlir
mlir/test/IR/core-ops.mlir
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/Builder/Character.cpp b/flang/lib/Optimizer/Builder/Character.cpp
index 87faa3b42c449..e4719133f3fa0 100644
--- a/flang/lib/Optimizer/Builder/Character.cpp
+++ b/flang/lib/Optimizer/Builder/Character.cpp
@@ -72,7 +72,7 @@ LLVM_ATTRIBUTE_UNUSED static bool needToMaterialize(mlir::Value str) {
/// Unwrap integer constant from mlir::Value.
static llvm::Optional<std::int64_t> getIntIfConstant(mlir::Value value) {
if (auto *definingOp = value.getDefiningOp())
- if (auto cst = mlir::dyn_cast<mlir::ConstantOp>(definingOp))
+ if (auto cst = mlir::dyn_cast<mlir::arith::ConstantOp>(definingOp))
if (auto intAttr = cst.getValue().dyn_cast<mlir::IntegerAttr>())
return intAttr.getInt();
return {};
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 9efe4ceb21473..2aca33eda3c46 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -376,23 +376,16 @@ def ConstantOp : Std_Op<"constant",
operation ::= ssa-id `=` `std.constant` attribute-value `:` type
```
- The `constant` operation produces an SSA value equal to some constant
- specified by an attribute. This is the way that MLIR uses to form simple
- integer and floating point constants, as well as more exotic things like
- references to functions and tensor/vector constants.
+ The `constant` operation produces an SSA value from a symbol reference to a
+ `builtin.func` operation.
Example:
```mlir
- // Complex constant
- %1 = constant [1.0 : f32, 1.0 : f32] : complex<f32>
-
// Reference to function @myfn.
%2 = constant @myfn : (tensor<16xf32>, f32) -> tensor<16xf32>
// Equivalent generic forms
- %1 = "std.constant"() {value = [1.0 : f32, 1.0 : f32] : complex<f32>}
- : () -> complex<f32>
%2 = "std.constant"() {value = @myfn}
: () -> ((tensor<16xf32>, f32) -> tensor<16xf32>)
```
@@ -403,15 +396,9 @@ def ConstantOp : Std_Op<"constant",
([rationale](../Rationale/Rationale.md#multithreading-the-compiler)).
}];
- let arguments = (ins AnyAttr:$value);
+ let arguments = (ins FlatSymbolRefAttr:$value);
let results = (outs AnyType);
-
- let builders = [
- OpBuilder<(ins "Attribute":$value),
- [{ build($_builder, $_state, value.getType(), value); }]>,
- OpBuilder<(ins "Attribute":$value, "Type":$type),
- [{ build($_builder, $_state, type, value); }]>,
- ];
+ let assemblyFormat = "attr-dict $value `:` type(results)";
let extraClassDeclaration = [{
/// Returns true if a constant operation can be built with the given value
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index b2e18ab8196f4..04c51422ed115 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -435,31 +435,19 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
LogicalResult
matchAndRewrite(ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // If constant refers to a function, convert it to "addressof".
- if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) {
- auto type = typeConverter->convertType(op.getResult().getType());
- if (!type || !LLVM::isCompatibleType(type))
- return rewriter.notifyMatchFailure(op, "failed to convert result type");
-
- auto newOp = rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type,
- symbolRef.getValue());
- for (const NamedAttribute &attr : op->getAttrs()) {
- if (attr.getName().strref() == "value")
- continue;
- newOp->setAttr(attr.getName(), attr.getValue());
- }
- rewriter.replaceOp(op, newOp->getResults());
- return success();
+ auto type = typeConverter->convertType(op.getResult().getType());
+ if (!type || !LLVM::isCompatibleType(type))
+ return rewriter.notifyMatchFailure(op, "failed to convert result type");
+
+ auto newOp =
+ rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
+ for (const NamedAttribute &attr : op->getAttrs()) {
+ if (attr.getName().strref() == "value")
+ continue;
+ newOp->setAttr(attr.getName(), attr.getValue());
}
-
- // Calling into other scopes (non-flat reference) is not supported in LLVM.
- if (op.getValue().isa<SymbolRefAttr>())
- return rewriter.notifyMatchFailure(
- op, "referring to a symbol outside of the current module");
-
- return LLVM::detail::oneToOneRewrite(
- op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
- *getTypeConverter(), rewriter);
+ rewriter.replaceOp(op, newOp->getResults());
+ return success();
}
};
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index e1c91fbbc1d98..74d6d42e2b9b0 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -291,7 +291,7 @@ static ParallelComputeFunction createParallelComputeFunction(
return llvm::to_vector(
llvm::map_range(llvm::zip(args, attrs), [&](auto tuple) -> Value {
if (IntegerAttr attr = std::get<1>(tuple))
- return b.create<ConstantOp>(attr);
+ return b.create<arith::ConstantOp>(attr);
return std::get<0>(tuple);
}));
};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 19e3c0318f574..32fd370012c44 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1576,7 +1576,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
: DenseElementsAttr::get(outputType, intOutputValues);
- rewriter.replaceOpWithNewOp<ConstantOp>(genericOp, outputAttr);
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(genericOp, outputAttr);
return success();
}
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 1237f0b47cf72..d47d6ead0273e 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -145,7 +145,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
// Stitch results together into one large vector.
Type resultEltType = results[0].getType().cast<VectorType>().getElementType();
Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
- Value result = builder.create<ConstantOp>(
+ Value result = builder.create<arith::ConstantOp>(
resultExpandedType, builder.getZeroAttr(resultExpandedType));
for (int64_t i = 0; i < maxLinearIndex; ++i)
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 65e72293ed3f0..bf35625adb625 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -115,7 +115,10 @@ Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder,
Location loc) {
if (arith::ConstantOp::isBuildableWith(value, type))
return builder.create<arith::ConstantOp>(loc, type, value);
- return builder.create<ConstantOp>(loc, type, value);
+ if (ConstantOp::isBuildableWith(value, type))
+ return builder.create<ConstantOp>(loc, type,
+ value.cast<FlatSymbolRefAttr>());
+ return nullptr;
}
//===----------------------------------------------------------------------===//
@@ -562,97 +565,35 @@ Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
// ConstantOp
//===----------------------------------------------------------------------===//
-static void print(OpAsmPrinter &p, ConstantOp &op) {
- p << " ";
- p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"});
-
- if (op->getAttrs().size() > 1)
- p << ' ';
- p << op.getValue();
-
- // If the value is a symbol reference, print a trailing type.
- if (op.getValue().isa<SymbolRefAttr>())
- p << " : " << op.getType();
-}
-
-static ParseResult parseConstantOp(OpAsmParser &parser,
- OperationState &result) {
- Attribute valueAttr;
- if (parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseAttribute(valueAttr, "value", result.attributes))
- return failure();
-
- // If the attribute is a symbol reference, then we expect a trailing type.
- Type type;
- if (!valueAttr.isa<SymbolRefAttr>())
- type = valueAttr.getType();
- else if (parser.parseColonType(type))
- return failure();
-
- // Add the attribute type to the list.
- return parser.addTypeToList(type, result.types);
-}
-
-/// The constant op requires an attribute, and furthermore requires that it
-/// matches the return type.
LogicalResult ConstantOp::verify() {
- auto value = getValue();
- if (!value)
- return emitOpError("requires a 'value' attribute");
-
+ StringRef fnName = getValue();
Type type = getType();
- if (!value.getType().isa<NoneType>() && type != value.getType())
- return emitOpError() << "requires attribute's type (" << value.getType()
- << ") to match op's return type (" << type << ")";
-
- if (type.isa<FunctionType>()) {
- auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
- if (!fnAttr)
- return emitOpError("requires 'value' to be a function reference");
-
- // Try to find the referenced function.
- auto fn = (*this)->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
- fnAttr.getValue());
- if (!fn)
- return emitOpError() << "reference to undefined function '"
- << fnAttr.getValue() << "'";
-
- // Check that the referenced function has the correct type.
- if (fn.getType() != type)
- return emitOpError("reference to function with mismatched type");
- return success();
- }
+ // Try to find the referenced function.
+ auto fn = (*this)->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnName);
+ if (!fn)
+ return emitOpError() << "reference to undefined function '" << fnName
+ << "'";
- if (type.isa<NoneType>() && value.isa<UnitAttr>())
- return success();
+ // Check that the referenced function has the correct type.
+ if (fn.getType() != type)
+ return emitOpError("reference to function with mismatched type");
- return emitOpError("unsupported 'value' attribute: ") << value;
+ return success();
}
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
assert(operands.empty() && "constant has no operands");
- return getValue();
+ return getValueAttr();
}
void ConstantOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
- Type type = getType();
- if (type.isa<FunctionType>()) {
- setNameFn(getResult(), "f");
- } else {
- setNameFn(getResult(), "cst");
- }
+ setNameFn(getResult(), "f");
}
-/// Returns true if a constant operation can be built with the given value and
-/// result type.
bool ConstantOp::isBuildableWith(Attribute value, Type type) {
- // SymbolRefAttr can only be used with a function type.
- if (value.isa<SymbolRefAttr>())
- return type.isa<FunctionType>();
- // Otherwise, this must be a UnitAttr.
- return value.isa<UnitAttr>() && type.isa<NoneType>();
+ return value.isa<FlatSymbolRefAttr>() && type.isa<FunctionType>();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
index 495de25662dbc..52b52763b0dc7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
@@ -307,7 +307,7 @@ struct TwoDimMultiReductionToReduction
return failure();
auto loc = multiReductionOp.getLoc();
- Value result = rewriter.create<ConstantOp>(
+ Value result = rewriter.create<arith::ConstantOp>(
loc, multiReductionOp.getDestType(),
rewriter.getZeroAttr(multiReductionOp.getDestType()));
int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index a332029a4cf81..5d7ef65fcad2e 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -232,7 +232,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
static LogicalResult printOperation(CppEmitter &emitter,
mlir::ConstantOp constantOp) {
Operation *operation = constantOp.getOperation();
- Attribute value = constantOp.getValue();
+ Attribute value = constantOp.getValueAttr();
return printConstantOp(emitter, operation, value);
}
diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
index 836158dd2160b..e9359936a1962 100644
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt -split-input-file %s -verify-diagnostics
func @unsupported_attribute() {
- // expected-error @+1 {{unsupported 'value' attribute: "" : index}}
+ // expected-error @+1 {{invalid kind of attribute specified}}
%0 = constant "" : index
return
}
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index fefe7387f284a..55280b2ac8b8c 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -99,9 +99,6 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
// CHECK: %{{.*}} = arith.cmpf oeq, %{{.*}}, %{{.*}}: vector<4xf32>
%70 = arith.cmpf oeq, %vcf32, %vcf32 : vector<4 x f32>
- // CHECK: = constant unit
- %73 = constant unit
-
// CHECK: arith.constant true
%74 = arith.constant true
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 34dd14176b45d..53661511ee324 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -578,7 +578,7 @@ struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
LogicalResult matchAndRewrite(ILLegalOpG op,
PatternRewriter &rewriter) const final {
IntegerAttr attr = rewriter.getI32IntegerAttr(0);
- Value val = rewriter.create<ConstantOp>(op->getLoc(), attr);
+ Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr);
rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
return success();
};
More information about the Mlir-commits
mailing list