[flang-commits] [flang] 60cac0c - [mlir][NFC] Remove deprecated/old build/fold/parser utilities from OpDefinition
River Riddle via flang-commits
flang-commits at lists.llvm.org
Mon Feb 7 19:04:29 PST 2022
Author: River Riddle
Date: 2022-02-07T19:03:58-08:00
New Revision: 60cac0c0816193f3d910cd8bdaebac8e6694a6bd
URL: https://github.com/llvm/llvm-project/commit/60cac0c0816193f3d910cd8bdaebac8e6694a6bd
DIFF: https://github.com/llvm/llvm-project/commit/60cac0c0816193f3d910cd8bdaebac8e6694a6bd.diff
LOG: [mlir][NFC] Remove deprecated/old build/fold/parser utilities from OpDefinition
These have generally been superseded by better ODS functionality (declarative
assembly formats and CastOpInterface), so they no longer need to be provided explicitly.
Differential Revision: https://reviews.llvm.org/D119065
Added:
Modified:
flang/include/flang/Optimizer/Dialect/FIROps.td
flang/lib/Optimizer/Dialect/FIROps.cpp
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/include/mlir/IR/OpDefinition.h
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/IR/Operation.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index e79fafad6cd0f..cf20385f46317 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2534,19 +2534,15 @@ def fir_StringLitOp : fir_Op<"string_lit", [NoSideEffect]> {
class fir_ArithmeticOp<string mnemonic, list<Trait> traits = []> :
fir_Op<mnemonic,
!listconcat(traits, [NoSideEffect, SameOperandsAndResultType])>,
- Results<(outs AnyType)> {
- let parser = "return impl::parseOneResultSameOperandTypeOp(parser, result);";
-
- let printer = "return printBinaryOp(this->getOperation(), p);";
+ Results<(outs AnyType:$result)> {
+ let assemblyFormat = "operands attr-dict `:` type($result)";
}
class fir_UnaryArithmeticOp<string mnemonic, list<Trait> traits = []> :
fir_Op<mnemonic,
!listconcat(traits, [NoSideEffect, SameOperandsAndResultType])>,
- Results<(outs AnyType)> {
- let parser = "return impl::parseOneResultSameOperandTypeOp(parser, result);";
-
- let printer = "return printUnaryOp(this->getOperation(), p);";
+ Results<(outs AnyType:$result)> {
+ let assemblyFormat = "operands attr-dict `:` type($result)";
}
def fir_ConstcOp : fir_Op<"constc", [NoSideEffect]> {
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index bc91111812a43..33d34b6ad51a8 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3211,26 +3211,6 @@ mlir::ParseResult fir::parseSelector(mlir::OpAsmParser &parser,
return mlir::success();
}
-/// Generic pretty-printer of a binary operation
-static void printBinaryOp(Operation *op, OpAsmPrinter &p) {
- assert(op->getNumOperands() == 2 && "binary op must have two operands");
- assert(op->getNumResults() == 1 && "binary op must have one result");
-
- p << ' ' << op->getOperand(0) << ", " << op->getOperand(1);
- p.printOptionalAttrDict(op->getAttrs());
- p << " : " << op->getResult(0).getType();
-}
-
-/// Generic pretty-printer of an unary operation
-static void printUnaryOp(Operation *op, OpAsmPrinter &p) {
- assert(op->getNumOperands() == 1 && "unary op must have one operand");
- assert(op->getNumResults() == 1 && "unary op must have one result");
-
- p << ' ' << op->getOperand(0);
- p.printOptionalAttrDict(op->getAttrs());
- p << " : " << op->getResult(0).getType();
-}
-
bool fir::isReferenceLike(mlir::Type type) {
return type.isa<fir::ReferenceType>() || type.isa<fir::HeapType>() ||
type.isa<fir::PointerType>();
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 78b3af47ae19b..4a40132df963a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -419,8 +419,7 @@ class LLVM_CastOp<string mnemonic, string builderFunc, Type type,
let arguments = (ins type:$arg);
let results = (outs resultType:$res);
let builders = [LLVM_OneResultOpBuilder];
- let parser = [{ return mlir::impl::parseCastOp(parser, result); }];
- let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }];
+ let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)";
}
def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast",
LLVM_AnyNonAggregate, LLVM_AnyNonAggregate> {
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 79ad1ed7f8046..abb8231d7e5b2 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -383,7 +383,6 @@ def MemRef_CastOp : MemRef_Op<"cast", [
}];
let hasFolder = 1;
- let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index ec9116c5d803f..c600b31676b3b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -34,6 +34,7 @@ class SPV_ArithmeticBinaryOp<string mnemonic, Type type,
let results = (outs
SPV_ScalarOrVectorOrCoopMatrixOf<type>:$result
);
+ let assemblyFormat = "operands attr-dict `:` type($result)";
}
class SPV_ArithmeticUnaryOp<string mnemonic, Type type,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 625e44681cae8..055835ea48081 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4338,8 +4338,6 @@ class SPV_BinaryOp<string mnemonic, Type resultType, Type operandsType,
SPV_ScalarOrVectorOf<resultType>:$result
);
- let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
- let printer = [{ return impl::printOneResultOp(getOperation(), p); }];
// No additional verification needed in addition to the ODS-generated ones.
let hasVerifier = 0;
}
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
index e0f16032badd1..b9b31f6dd5abb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
@@ -21,7 +21,9 @@ class SPV_BitBinaryOp<string mnemonic, list<Trait> traits = []> :
// All the operands type used in bit instructions are SPV_Integer.
SPV_BinaryOp<mnemonic, SPV_Integer, SPV_Integer,
!listconcat(traits,
- [NoSideEffect, SameOperandsAndResultType])>;
+ [NoSideEffect, SameOperandsAndResultType])> {
+ let assemblyFormat = "operands attr-dict `:` type($result)";
+}
class SPV_BitFieldExtractOp<string mnemonic, list<Trait> traits = []> :
SPV_Op<mnemonic, !listconcat(traits,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
index ec2c73f03ac25..5048dd10ae575 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
@@ -29,9 +29,9 @@ class SPV_CastOp<string mnemonic, Type resultType, Type operandType,
let results = (outs
SPV_ScalarOrVectorOrCoopMatrixOf<resultType>:$result
);
-
- let parser = [{ return mlir::impl::parseCastOp(parser, result); }];
- let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }];
+ let assemblyFormat = [{
+ $operand attr-dict `:` type($operand) `to` type($result)
+ }];
}
// -----
@@ -85,9 +85,9 @@ def SPV_BitcastOp : SPV_Op<"Bitcast", [NoSideEffect]> {
SPV_ScalarOrVectorOrPtr:$result
);
- let parser = [{ return mlir::impl::parseCastOp(parser, result); }];
- let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }];
-
+ let assemblyFormat = [{
+ $operand attr-dict `:` type($operand) `to` type($result)
+ }];
let hasCanonicalizer = 1;
}
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
index 1532d3ea7133a..f0d5515d2bcf7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
@@ -72,10 +72,6 @@ class SPV_GLSLBinaryOp<string mnemonic, int opcode, Type resultType,
SPV_ScalarOrVectorOf<resultType>:$result
);
- let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
-
- let printer = [{ return impl::printOneResultOp(getOperation(), p); }];
-
let hasVerifier = 0;
}
@@ -83,7 +79,10 @@ class SPV_GLSLBinaryOp<string mnemonic, int opcode, Type resultType,
// return type matches.
class SPV_GLSLBinaryArithmeticOp<string mnemonic, int opcode, Type type,
list<Trait> traits = []> :
- SPV_GLSLBinaryOp<mnemonic, opcode, type, type, traits>;
+ SPV_GLSLBinaryOp<mnemonic, opcode, type, type,
+ traits # [SameOperandsAndResultType]> {
+ let assemblyFormat = "operands attr-dict `:` type($result)";
+}
// Base class for GLSL ternary ops.
class SPV_GLSLTernaryArithmeticOp<string mnemonic, int opcode, Type type,
@@ -100,9 +99,8 @@ class SPV_GLSLTernaryArithmeticOp<string mnemonic, int opcode, Type type,
SPV_ScalarOrVectorOf<type>:$result
);
- let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
-
- let printer = [{ return impl::printOneResultOp(getOperation(), p); }];
+ let parser = [{ return parseOneResultSameOperandTypeOp(parser, result); }];
+ let printer = [{ return printOneResultOp(getOperation(), p); }];
let hasVerifier = 0;
}
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
index 92fb8a46478a2..0b3c08a510114 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
@@ -71,10 +71,6 @@ class SPV_OCLBinaryOp<string mnemonic, int opcode, Type resultType,
SPV_ScalarOrVectorOf<resultType>:$result
);
- let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
-
- let printer = [{ return impl::printOneResultOp(getOperation(), p); }];
-
let hasVerifier = 0;
}
@@ -82,7 +78,10 @@ class SPV_OCLBinaryOp<string mnemonic, int opcode, Type resultType,
// return type matches.
class SPV_OCLBinaryArithmeticOp<string mnemonic, int opcode, Type type,
list<Trait> traits = []> :
- SPV_OCLBinaryOp<mnemonic, opcode, type, type, traits>;
+ SPV_OCLBinaryOp<mnemonic, opcode, type, type,
+ traits # [SameOperandsAndResultType]> {
+ let assemblyFormat = "operands attr-dict `:` type($result)";
+}
// -----
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index ee745e36769aa..f4a499a4bd7bd 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -14,6 +14,7 @@
#define SHAPE_OPS
include "mlir/Dialect/Shape/IR/ShapeBase.td"
+include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -331,7 +332,9 @@ def Shape_RankOp : Shape_Op<"rank",
}];
}
-def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
+def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [
+ DeclareOpInterfaceMethods<CastOpInterface>, NoSideEffect
+ ]> {
let summary = "Creates a dimension tensor from a shape";
let description = [{
Converts a shape to a 1D integral tensor of extents. The number of elements
@@ -624,7 +627,9 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of",
}];
}
-def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {
+def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [
+ DeclareOpInterfaceMethods<CastOpInterface>, NoSideEffect
+ ]> {
let summary = "Casts between index types of the shape and standard dialect";
let description = [{
Converts a `shape.size` to a standard index. This operation and its
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index ac51c7bb51ddb..e29b35f211857 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1897,26 +1897,9 @@ class OpInterface
};
//===----------------------------------------------------------------------===//
-// Common Operation Folders/Parsers/Printers
+// CastOpInterface utilities
//===----------------------------------------------------------------------===//
-// These functions are out-of-line implementations of the methods in UnaryOp and
-// BinaryOp, which avoids them being template instantiated/duplicated.
-namespace impl {
-ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser,
- OperationState &result);
-
-void buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs,
- Value rhs);
-ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
- OperationState &result);
-
-// Prints the given binary `op` in custom assembly form if both the two operands
-// and the result have the same time. Otherwise, prints the generic assembly
-// form.
-void printOneResultOp(Operation *op, OpAsmPrinter &p);
-} // namespace impl
-
// These functions are out-of-line implementations of the methods in
// CastOpInterface, which avoids them being template instantiated/duplicated.
namespace impl {
@@ -1927,20 +1910,6 @@ LogicalResult foldCastInterfaceOp(Operation *op,
/// Attempt to verify the given cast operation.
LogicalResult verifyCastInterfaceOp(
Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible);
-
-// TODO: Remove the parse/print/build here (new ODS functionality obsoletes the
-// need for them, but some older ODS code in `std` still depends on them).
-void buildCastOp(OpBuilder &builder, OperationState &result, Value source,
- Type destType);
-ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
-void printCastOp(Operation *op, OpAsmPrinter &p);
-// TODO: These methods are deprecated in favor of CastOpInterface. Remove them
-// when all uses have been updated. Also, consider adding functionality to
-// CastOpInterface to be able to perform the ChainedTensorCast canonicalization
-// generically.
-Value foldCastOp(Operation *op);
-LogicalResult verifyCastOp(Operation *op,
- function_ref<bool(Type, Type)> areCastCompatible);
} // namespace impl
} // namespace mlir
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index da672dc86b66d..33192e1bb704b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -65,10 +65,6 @@ Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
return NoneType::get(type.getContext());
}
-LogicalResult memref::CastOp::verify() {
- return impl::verifyCastOp(*this, areCastCompatible);
-}
-
//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index cb476dcb62307..1d060473547d3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -64,6 +64,54 @@ static constexpr const char kCompositeSpecConstituentsName[] = "constituents";
// Common utility functions
//===----------------------------------------------------------------------===//
+static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
+ OperationState &result) {
+ SmallVector<OpAsmParser::OperandType, 2> ops;
+ Type type;
+ // If the operand list is in-between parentheses, then we have a generic form.
+ // (see the fallback in `printOneResultOp`).
+ SMLoc loc = parser.getCurrentLocation();
+ if (!parser.parseOptionalLParen()) {
+ if (parser.parseOperandList(ops) || parser.parseRParen() ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColon() || parser.parseType(type))
+ return failure();
+ auto fnType = type.dyn_cast<FunctionType>();
+ if (!fnType) {
+ parser.emitError(loc, "expected function type");
+ return failure();
+ }
+ if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
+ return failure();
+ result.addTypes(fnType.getResults());
+ return success();
+ }
+ return failure(parser.parseOperandList(ops) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(type) ||
+ parser.resolveOperands(ops, type, result.operands) ||
+ parser.addTypeToList(type, result.types));
+}
+
+static void printOneResultOp(Operation *op, OpAsmPrinter &p) {
+ assert(op->getNumResults() == 1 && "op should have one result");
+
+ // If not all the operand and result types are the same, just use the
+ // generic assembly form to avoid omitting information in printing.
+ auto resultType = op->getResult(0).getType();
+ if (llvm::any_of(op->getOperandTypes(),
+ [&](Type type) { return type != resultType; })) {
+ p.printGenericOp(op, /*printOpName=*/false);
+ return;
+ }
+
+ p << ' ';
+ p.printOperands(op->getOperands());
+ p.printOptionalAttrDict(op->getAttrs());
+ // Now we can output only one type for all operands and the result.
+ p << " : " << resultType;
+}
+
/// Returns true if the given op is a function-like op or nested in a
/// function-like op without a module-like op in the middle.
static bool isNestedInFunctionOpInterface(Operation *op) {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index a25f6dd7b1cb5..9241e2d12ba22 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1692,7 +1692,7 @@ OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
// `IntegerAttr`s which makes constant folding simple.
if (Attribute arg = operands[0])
return arg;
- return impl::foldCastOp(*this);
+ return OpFoldResult();
}
void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -1700,6 +1700,12 @@ void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<IndexToSizeToIndexCanonicalization>(context);
}
+bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+ if (inputs.size() != 1 || outputs.size() != 1)
+ return false;
+ return inputs[0].isa<IndexType, SizeType>() && outputs[0].isa<IndexType>();
+}
+
//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//
@@ -1750,7 +1756,7 @@ LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
if (!operands[0])
- return impl::foldCastOp(*this);
+ return OpFoldResult();
Builder builder(getContext());
auto shape = llvm::to_vector<6>(
operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
@@ -1759,6 +1765,21 @@ OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
return DenseIntElementsAttr::get(type, shape);
}
+bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+ if (inputs.size() != 1 || outputs.size() != 1)
+ return false;
+ if (auto inputTensor = inputs[0].dyn_cast<RankedTensorType>()) {
+ if (!inputTensor.getElementType().isa<IndexType>() ||
+ inputTensor.getRank() != 1 || !inputTensor.isDynamicDim(0))
+ return false;
+ } else if (!inputs[0].isa<ShapeType>()) {
+ return false;
+ }
+
+ TensorType outputTensor = outputs[0].dyn_cast<TensorType>();
+ return outputTensor && outputTensor.getElementType().isa<IndexType>();
+}
+
//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index e67933790de11..72738971db6d5 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1125,69 +1125,7 @@ bool OpTrait::hasElementwiseMappableTraits(Operation *op) {
}
//===----------------------------------------------------------------------===//
-// BinaryOp implementation
-//===----------------------------------------------------------------------===//
-
-// These functions are out-of-line implementations of the methods in BinaryOp,
-// which avoids them being template instantiated/duplicated.
-
-void impl::buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs,
- Value rhs) {
- assert(lhs.getType() == rhs.getType());
- result.addOperands({lhs, rhs});
- result.types.push_back(lhs.getType());
-}
-
-ParseResult impl::parseOneResultSameOperandTypeOp(OpAsmParser &parser,
- OperationState &result) {
- SmallVector<OpAsmParser::OperandType, 2> ops;
- Type type;
- // If the operand list is in-between parentheses, then we have a generic form.
- // (see the fallback in `printOneResultOp`).
- SMLoc loc = parser.getCurrentLocation();
- if (!parser.parseOptionalLParen()) {
- if (parser.parseOperandList(ops) || parser.parseRParen() ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColon() || parser.parseType(type))
- return failure();
- auto fnType = type.dyn_cast<FunctionType>();
- if (!fnType) {
- parser.emitError(loc, "expected function type");
- return failure();
- }
- if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
- return failure();
- result.addTypes(fnType.getResults());
- return success();
- }
- return failure(parser.parseOperandList(ops) ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(type) ||
- parser.resolveOperands(ops, type, result.operands) ||
- parser.addTypeToList(type, result.types));
-}
-
-void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) {
- assert(op->getNumResults() == 1 && "op should have one result");
-
- // If not all the operand and result types are the same, just use the
- // generic assembly form to avoid omitting information in printing.
- auto resultType = op->getResult(0).getType();
- if (llvm::any_of(op->getOperandTypes(),
- [&](Type type) { return type != resultType; })) {
- p.printGenericOp(op, /*printOpName=*/false);
- return;
- }
-
- p << ' ';
- p.printOperands(op->getOperands());
- p.printOptionalAttrDict(op->getAttrs());
- // Now we can output only one type for all operands and the result.
- p << " : " << resultType;
-}
-
-//===----------------------------------------------------------------------===//
-// CastOp implementation
+// CastOpInterface
//===----------------------------------------------------------------------===//
/// Attempt to fold the given cast operation.
@@ -1232,50 +1170,6 @@ LogicalResult impl::verifyCastInterfaceOp(
return success();
}
-void impl::buildCastOp(OpBuilder &builder, OperationState &result, Value source,
- Type destType) {
- result.addOperands(source);
- result.addTypes(destType);
-}
-
-ParseResult impl::parseCastOp(OpAsmParser &parser, OperationState &result) {
- OpAsmParser::OperandType srcInfo;
- Type srcType, dstType;
- return failure(parser.parseOperand(srcInfo) ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(srcType) ||
- parser.resolveOperand(srcInfo, srcType, result.operands) ||
- parser.parseKeywordType("to", dstType) ||
- parser.addTypeToList(dstType, result.types));
-}
-
-void impl::printCastOp(Operation *op, OpAsmPrinter &p) {
- p << ' ' << op->getOperand(0);
- p.printOptionalAttrDict(op->getAttrs());
- p << " : " << op->getOperand(0).getType() << " to "
- << op->getResult(0).getType();
-}
-
-Value impl::foldCastOp(Operation *op) {
- // Identity cast
- if (op->getOperand(0).getType() == op->getResult(0).getType())
- return op->getOperand(0);
- return nullptr;
-}
-
-LogicalResult
-impl::verifyCastOp(Operation *op,
- function_ref<bool(Type, Type)> areCastCompatible) {
- auto opType = op->getOperand(0).getType();
- auto resType = op->getResult(0).getType();
- if (!areCastCompatible(opType, resType))
- return op->emitError("operand type ")
- << opType << " and result type " << resType
- << " are cast incompatible";
-
- return success();
-}
-
//===----------------------------------------------------------------------===//
// Misc. utils
//===----------------------------------------------------------------------===//
More information about the flang-commits
mailing list