[Mlir-commits] [mlir] 60cac0c - [mlir][NFC] Remove deprecated/old build/fold/parser utilities from OpDefinition

River Riddle llvmlistbot at llvm.org
Mon Feb 7 19:04:27 PST 2022


Author: River Riddle
Date: 2022-02-07T19:03:58-08:00
New Revision: 60cac0c0816193f3d910cd8bdaebac8e6694a6bd

URL: https://github.com/llvm/llvm-project/commit/60cac0c0816193f3d910cd8bdaebac8e6694a6bd
DIFF: https://github.com/llvm/llvm-project/commit/60cac0c0816193f3d910cd8bdaebac8e6694a6bd.diff

LOG: [mlir][NFC] Remove deprecated/old build/fold/parser utilities from OpDefinition

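These have generally been replaced by better ODS functionality, and do not
need to be explicitly provided anymore. This also removes the matching
printer and verifier helpers: ops that used the generic parser/printer
helpers now use declarative assembly formats, the Shape cast-like ops now
implement CastOpInterface instead of calling the old cast folding helper,
memref.cast drops its custom verifier, and the remaining SPIR-V users of the
one-result parser/printer get file-local copies in SPIRVOps.cpp.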

Differential Revision: https://reviews.llvm.org/D119065

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Dialect/FIROps.td
    flang/lib/Optimizer/Dialect/FIROps.cpp
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/include/mlir/IR/OpDefinition.h
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/IR/Operation.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index e79fafad6cd0f..cf20385f46317 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2534,19 +2534,15 @@ def fir_StringLitOp : fir_Op<"string_lit", [NoSideEffect]> {
 class fir_ArithmeticOp<string mnemonic, list<Trait> traits = []> :
     fir_Op<mnemonic,
            !listconcat(traits, [NoSideEffect, SameOperandsAndResultType])>,
-    Results<(outs AnyType)> {
-  let parser = "return impl::parseOneResultSameOperandTypeOp(parser, result);";
-
-  let printer = "return printBinaryOp(this->getOperation(), p);";
+    Results<(outs AnyType:$result)> {
+  let assemblyFormat = "operands attr-dict `:` type($result)";
 }
 
 class fir_UnaryArithmeticOp<string mnemonic, list<Trait> traits = []> :
       fir_Op<mnemonic,
              !listconcat(traits, [NoSideEffect, SameOperandsAndResultType])>,
-      Results<(outs AnyType)> {
-  let parser = "return impl::parseOneResultSameOperandTypeOp(parser, result);";
-
-  let printer = "return printUnaryOp(this->getOperation(), p);";
+      Results<(outs AnyType:$result)> {
+  let assemblyFormat = "operands attr-dict `:` type($result)";
 }
 
 def fir_ConstcOp : fir_Op<"constc", [NoSideEffect]> {

diff  --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index bc91111812a43..33d34b6ad51a8 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3211,26 +3211,6 @@ mlir::ParseResult fir::parseSelector(mlir::OpAsmParser &parser,
   return mlir::success();
 }
 
-/// Generic pretty-printer of a binary operation
-static void printBinaryOp(Operation *op, OpAsmPrinter &p) {
-  assert(op->getNumOperands() == 2 && "binary op must have two operands");
-  assert(op->getNumResults() == 1 && "binary op must have one result");
-
-  p << ' ' << op->getOperand(0) << ", " << op->getOperand(1);
-  p.printOptionalAttrDict(op->getAttrs());
-  p << " : " << op->getResult(0).getType();
-}
-
-/// Generic pretty-printer of an unary operation
-static void printUnaryOp(Operation *op, OpAsmPrinter &p) {
-  assert(op->getNumOperands() == 1 && "unary op must have one operand");
-  assert(op->getNumResults() == 1 && "unary op must have one result");
-
-  p << ' ' << op->getOperand(0);
-  p.printOptionalAttrDict(op->getAttrs());
-  p << " : " << op->getResult(0).getType();
-}
-
 bool fir::isReferenceLike(mlir::Type type) {
   return type.isa<fir::ReferenceType>() || type.isa<fir::HeapType>() ||
          type.isa<fir::PointerType>();

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 78b3af47ae19b..4a40132df963a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -419,8 +419,7 @@ class LLVM_CastOp<string mnemonic, string builderFunc, Type type,
   let arguments = (ins type:$arg);
   let results = (outs resultType:$res);
   let builders = [LLVM_OneResultOpBuilder];
-  let parser = [{ return mlir::impl::parseCastOp(parser, result); }];
-  let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }];
+  let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)";
 }
 def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast",
                                  LLVM_AnyNonAggregate, LLVM_AnyNonAggregate> {

diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 79ad1ed7f8046..abb8231d7e5b2 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -383,7 +383,6 @@ def MemRef_CastOp : MemRef_Op<"cast", [
   }];
 
   let hasFolder = 1;
-  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index ec9116c5d803f..c600b31676b3b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -34,6 +34,7 @@ class SPV_ArithmeticBinaryOp<string mnemonic, Type type,
   let results = (outs
     SPV_ScalarOrVectorOrCoopMatrixOf<type>:$result
   );
+  let assemblyFormat = "operands attr-dict `:` type($result)";
 }
 
 class SPV_ArithmeticUnaryOp<string mnemonic, Type type,

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 625e44681cae8..055835ea48081 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4338,8 +4338,6 @@ class SPV_BinaryOp<string mnemonic, Type resultType, Type operandsType,
     SPV_ScalarOrVectorOf<resultType>:$result
   );
 
-  let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
-  let printer = [{ return impl::printOneResultOp(getOperation(), p); }];
   // No additional verification needed in addition to the ODS-generated ones.
   let hasVerifier = 0;
 }

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
index e0f16032badd1..b9b31f6dd5abb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
@@ -21,7 +21,9 @@ class SPV_BitBinaryOp<string mnemonic, list<Trait> traits = []> :
       // All the operands type used in bit instructions are SPV_Integer.
       SPV_BinaryOp<mnemonic, SPV_Integer, SPV_Integer,
                    !listconcat(traits,
-                               [NoSideEffect, SameOperandsAndResultType])>;
+                               [NoSideEffect, SameOperandsAndResultType])> {                                 
+  let assemblyFormat = "operands attr-dict `:` type($result)";
+}
 
 class SPV_BitFieldExtractOp<string mnemonic, list<Trait> traits = []> :
       SPV_Op<mnemonic, !listconcat(traits,

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
index ec2c73f03ac25..5048dd10ae575 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
@@ -29,9 +29,9 @@ class SPV_CastOp<string mnemonic, Type resultType, Type operandType,
   let results = (outs
     SPV_ScalarOrVectorOrCoopMatrixOf<resultType>:$result
   );
-
-  let parser = [{ return mlir::impl::parseCastOp(parser, result); }];
-  let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }];
+  let assemblyFormat = [{
+    $operand attr-dict `:` type($operand) `to` type($result)
+  }];
 }
 
 // -----
@@ -85,9 +85,9 @@ def SPV_BitcastOp : SPV_Op<"Bitcast", [NoSideEffect]> {
     SPV_ScalarOrVectorOrPtr:$result
   );
 
-  let parser = [{ return mlir::impl::parseCastOp(parser, result); }];
-  let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }];
-
+  let assemblyFormat = [{
+    $operand attr-dict `:` type($operand) `to` type($result)
+  }];
   let hasCanonicalizer = 1;
 }
 

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
index 1532d3ea7133a..f0d5515d2bcf7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
@@ -72,10 +72,6 @@ class SPV_GLSLBinaryOp<string mnemonic, int opcode, Type resultType,
     SPV_ScalarOrVectorOf<resultType>:$result
   );
 
-  let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
-
-  let printer = [{ return impl::printOneResultOp(getOperation(), p); }];
-
   let hasVerifier = 0;
 }
 
@@ -83,7 +79,10 @@ class SPV_GLSLBinaryOp<string mnemonic, int opcode, Type resultType,
 // return type matches.
 class SPV_GLSLBinaryArithmeticOp<string mnemonic, int opcode, Type type,
                                  list<Trait> traits = []> :
-  SPV_GLSLBinaryOp<mnemonic, opcode, type, type, traits>;
+  SPV_GLSLBinaryOp<mnemonic, opcode, type, type,
+                   traits # [SameOperandsAndResultType]> {
+  let assemblyFormat = "operands attr-dict `:` type($result)";
+}
 
 // Base class for GLSL ternary ops.
 class SPV_GLSLTernaryArithmeticOp<string mnemonic, int opcode, Type type,
@@ -100,9 +99,8 @@ class SPV_GLSLTernaryArithmeticOp<string mnemonic, int opcode, Type type,
     SPV_ScalarOrVectorOf<type>:$result
   );
 
-  let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
-
-  let printer = [{ return impl::printOneResultOp(getOperation(), p); }];
+  let parser = [{ return parseOneResultSameOperandTypeOp(parser, result); }];
+  let printer = [{ return printOneResultOp(getOperation(), p); }];
 
   let hasVerifier = 0;
 }

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
index 92fb8a46478a2..0b3c08a510114 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
@@ -71,10 +71,6 @@ class SPV_OCLBinaryOp<string mnemonic, int opcode, Type resultType,
     SPV_ScalarOrVectorOf<resultType>:$result
   );
 
-  let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
-
-  let printer = [{ return impl::printOneResultOp(getOperation(), p); }];
-
   let hasVerifier = 0;
 }
 
@@ -82,7 +78,10 @@ class SPV_OCLBinaryOp<string mnemonic, int opcode, Type resultType,
 // return type matches.
 class SPV_OCLBinaryArithmeticOp<string mnemonic, int opcode, Type type,
                                 list<Trait> traits = []> :
-  SPV_OCLBinaryOp<mnemonic, opcode, type, type, traits>;
+  SPV_OCLBinaryOp<mnemonic, opcode, type, type,
+                  traits # [SameOperandsAndResultType]> {                    
+  let assemblyFormat = "operands attr-dict `:` type($result)";
+}
 
 // -----
 

diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index ee745e36769aa..f4a499a4bd7bd 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -14,6 +14,7 @@
 #define SHAPE_OPS
 
 include "mlir/Dialect/Shape/IR/ShapeBase.td"
+include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -331,7 +332,9 @@ def Shape_RankOp : Shape_Op<"rank",
   }];
 }
 
-def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
+def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [
+    DeclareOpInterfaceMethods<CastOpInterface>, NoSideEffect
+  ]> {
   let summary = "Creates a dimension tensor from a shape";
   let description = [{
     Converts a shape to a 1D integral tensor of extents. The number of elements
@@ -624,7 +627,9 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of",
   }];
 }
 
-def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {
+def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [
+    DeclareOpInterfaceMethods<CastOpInterface>, NoSideEffect
+  ]> {
   let summary = "Casts between index types of the shape and standard dialect";
   let description = [{
     Converts a `shape.size` to a standard index. This operation and its

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index ac51c7bb51ddb..e29b35f211857 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1897,26 +1897,9 @@ class OpInterface
 };
 
 //===----------------------------------------------------------------------===//
-// Common Operation Folders/Parsers/Printers
+// CastOpInterface utilities
 //===----------------------------------------------------------------------===//
 
-// These functions are out-of-line implementations of the methods in UnaryOp and
-// BinaryOp, which avoids them being template instantiated/duplicated.
-namespace impl {
-ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser,
-                                           OperationState &result);
-
-void buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs,
-                   Value rhs);
-ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
-                                            OperationState &result);
-
-// Prints the given binary `op` in custom assembly form if both the two operands
-// and the result have the same time. Otherwise, prints the generic assembly
-// form.
-void printOneResultOp(Operation *op, OpAsmPrinter &p);
-} // namespace impl
-
 // These functions are out-of-line implementations of the methods in
 // CastOpInterface, which avoids them being template instantiated/duplicated.
 namespace impl {
@@ -1927,20 +1910,6 @@ LogicalResult foldCastInterfaceOp(Operation *op,
 /// Attempt to verify the given cast operation.
 LogicalResult verifyCastInterfaceOp(
     Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible);
-
-// TODO: Remove the parse/print/build here (new ODS functionality obsoletes the
-// need for them, but some older ODS code in `std` still depends on them).
-void buildCastOp(OpBuilder &builder, OperationState &result, Value source,
-                 Type destType);
-ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
-void printCastOp(Operation *op, OpAsmPrinter &p);
-// TODO: These methods are deprecated in favor of CastOpInterface. Remove them
-// when all uses have been updated. Also, consider adding functionality to
-// CastOpInterface to be able to perform the ChainedTensorCast canonicalization
-// generically.
-Value foldCastOp(Operation *op);
-LogicalResult verifyCastOp(Operation *op,
-                           function_ref<bool(Type, Type)> areCastCompatible);
 } // namespace impl
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index da672dc86b66d..33192e1bb704b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -65,10 +65,6 @@ Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
   return NoneType::get(type.getContext());
 }
 
-LogicalResult memref::CastOp::verify() {
-  return impl::verifyCastOp(*this, areCastCompatible);
-}
-
 //===----------------------------------------------------------------------===//
 // AllocOp / AllocaOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index cb476dcb62307..1d060473547d3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -64,6 +64,54 @@ static constexpr const char kCompositeSpecConstituentsName[] = "constituents";
 // Common utility functions
 //===----------------------------------------------------------------------===//
 
+static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
+                                                   OperationState &result) {
+  SmallVector<OpAsmParser::OperandType, 2> ops;
+  Type type;
+  // If the operand list is in-between parentheses, then we have a generic form.
+  // (see the fallback in `printOneResultOp`).
+  SMLoc loc = parser.getCurrentLocation();
+  if (!parser.parseOptionalLParen()) {
+    if (parser.parseOperandList(ops) || parser.parseRParen() ||
+        parser.parseOptionalAttrDict(result.attributes) ||
+        parser.parseColon() || parser.parseType(type))
+      return failure();
+    auto fnType = type.dyn_cast<FunctionType>();
+    if (!fnType) {
+      parser.emitError(loc, "expected function type");
+      return failure();
+    }
+    if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
+      return failure();
+    result.addTypes(fnType.getResults());
+    return success();
+  }
+  return failure(parser.parseOperandList(ops) ||
+                 parser.parseOptionalAttrDict(result.attributes) ||
+                 parser.parseColonType(type) ||
+                 parser.resolveOperands(ops, type, result.operands) ||
+                 parser.addTypeToList(type, result.types));
+}
+
+static void printOneResultOp(Operation *op, OpAsmPrinter &p) {
+  assert(op->getNumResults() == 1 && "op should have one result");
+
+  // If not all the operand and result types are the same, just use the
+  // generic assembly form to avoid omitting information in printing.
+  auto resultType = op->getResult(0).getType();
+  if (llvm::any_of(op->getOperandTypes(),
+                   [&](Type type) { return type != resultType; })) {
+    p.printGenericOp(op, /*printOpName=*/false);
+    return;
+  }
+
+  p << ' ';
+  p.printOperands(op->getOperands());
+  p.printOptionalAttrDict(op->getAttrs());
+  // Now we can output only one type for all operands and the result.
+  p << " : " << resultType;
+}
+
 /// Returns true if the given op is a function-like op or nested in a
 /// function-like op without a module-like op in the middle.
 static bool isNestedInFunctionOpInterface(Operation *op) {

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index a25f6dd7b1cb5..9241e2d12ba22 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1692,7 +1692,7 @@ OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
   // `IntegerAttr`s which makes constant folding simple.
   if (Attribute arg = operands[0])
     return arg;
-  return impl::foldCastOp(*this);
+  return OpFoldResult();
 }
 
 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -1700,6 +1700,12 @@ void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<IndexToSizeToIndexCanonicalization>(context);
 }
 
+bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  return inputs[0].isa<IndexType, SizeType>() && outputs[0].isa<IndexType>();
+}
+
 //===----------------------------------------------------------------------===//
 // YieldOp
 //===----------------------------------------------------------------------===//
@@ -1750,7 +1756,7 @@ LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
 
 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
   if (!operands[0])
-    return impl::foldCastOp(*this);
+    return OpFoldResult();
   Builder builder(getContext());
   auto shape = llvm::to_vector<6>(
       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
@@ -1759,6 +1765,21 @@ OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
   return DenseIntElementsAttr::get(type, shape);
 }
 
+bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  if (auto inputTensor = inputs[0].dyn_cast<RankedTensorType>()) {
+    if (!inputTensor.getElementType().isa<IndexType>() ||
+        inputTensor.getRank() != 1 || !inputTensor.isDynamicDim(0))
+      return false;
+  } else if (!inputs[0].isa<ShapeType>()) {
+    return false;
+  }
+
+  TensorType outputTensor = outputs[0].dyn_cast<TensorType>();
+  return outputTensor && outputTensor.getElementType().isa<IndexType>();
+}
+
 //===----------------------------------------------------------------------===//
 // ReduceOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index e67933790de11..72738971db6d5 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1125,69 +1125,7 @@ bool OpTrait::hasElementwiseMappableTraits(Operation *op) {
 }
 
 //===----------------------------------------------------------------------===//
-// BinaryOp implementation
-//===----------------------------------------------------------------------===//
-
-// These functions are out-of-line implementations of the methods in BinaryOp,
-// which avoids them being template instantiated/duplicated.
-
-void impl::buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs,
-                         Value rhs) {
-  assert(lhs.getType() == rhs.getType());
-  result.addOperands({lhs, rhs});
-  result.types.push_back(lhs.getType());
-}
-
-ParseResult impl::parseOneResultSameOperandTypeOp(OpAsmParser &parser,
-                                                  OperationState &result) {
-  SmallVector<OpAsmParser::OperandType, 2> ops;
-  Type type;
-  // If the operand list is in-between parentheses, then we have a generic form.
-  // (see the fallback in `printOneResultOp`).
-  SMLoc loc = parser.getCurrentLocation();
-  if (!parser.parseOptionalLParen()) {
-    if (parser.parseOperandList(ops) || parser.parseRParen() ||
-        parser.parseOptionalAttrDict(result.attributes) ||
-        parser.parseColon() || parser.parseType(type))
-      return failure();
-    auto fnType = type.dyn_cast<FunctionType>();
-    if (!fnType) {
-      parser.emitError(loc, "expected function type");
-      return failure();
-    }
-    if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
-      return failure();
-    result.addTypes(fnType.getResults());
-    return success();
-  }
-  return failure(parser.parseOperandList(ops) ||
-                 parser.parseOptionalAttrDict(result.attributes) ||
-                 parser.parseColonType(type) ||
-                 parser.resolveOperands(ops, type, result.operands) ||
-                 parser.addTypeToList(type, result.types));
-}
-
-void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) {
-  assert(op->getNumResults() == 1 && "op should have one result");
-
-  // If not all the operand and result types are the same, just use the
-  // generic assembly form to avoid omitting information in printing.
-  auto resultType = op->getResult(0).getType();
-  if (llvm::any_of(op->getOperandTypes(),
-                   [&](Type type) { return type != resultType; })) {
-    p.printGenericOp(op, /*printOpName=*/false);
-    return;
-  }
-
-  p << ' ';
-  p.printOperands(op->getOperands());
-  p.printOptionalAttrDict(op->getAttrs());
-  // Now we can output only one type for all operands and the result.
-  p << " : " << resultType;
-}
-
-//===----------------------------------------------------------------------===//
-// CastOp implementation
+// CastOpInterface
 //===----------------------------------------------------------------------===//
 
 /// Attempt to fold the given cast operation.
@@ -1232,50 +1170,6 @@ LogicalResult impl::verifyCastInterfaceOp(
   return success();
 }
 
-void impl::buildCastOp(OpBuilder &builder, OperationState &result, Value source,
-                       Type destType) {
-  result.addOperands(source);
-  result.addTypes(destType);
-}
-
-ParseResult impl::parseCastOp(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::OperandType srcInfo;
-  Type srcType, dstType;
-  return failure(parser.parseOperand(srcInfo) ||
-                 parser.parseOptionalAttrDict(result.attributes) ||
-                 parser.parseColonType(srcType) ||
-                 parser.resolveOperand(srcInfo, srcType, result.operands) ||
-                 parser.parseKeywordType("to", dstType) ||
-                 parser.addTypeToList(dstType, result.types));
-}
-
-void impl::printCastOp(Operation *op, OpAsmPrinter &p) {
-  p << ' ' << op->getOperand(0);
-  p.printOptionalAttrDict(op->getAttrs());
-  p << " : " << op->getOperand(0).getType() << " to "
-    << op->getResult(0).getType();
-}
-
-Value impl::foldCastOp(Operation *op) {
-  // Identity cast
-  if (op->getOperand(0).getType() == op->getResult(0).getType())
-    return op->getOperand(0);
-  return nullptr;
-}
-
-LogicalResult
-impl::verifyCastOp(Operation *op,
-                   function_ref<bool(Type, Type)> areCastCompatible) {
-  auto opType = op->getOperand(0).getType();
-  auto resType = op->getResult(0).getType();
-  if (!areCastCompatible(opType, resType))
-    return op->emitError("operand type ")
-           << opType << " and result type " << resType
-           << " are cast incompatible";
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // Misc. utils
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list