[Mlir-commits] [mlir] 1af15de - [mlir] Switch {collapse,expand}_shape ops to the declarative assembly format
Benjamin Kramer
llvmlistbot at llvm.org
Thu Feb 17 11:04:42 PST 2022
Author: Benjamin Kramer
Date: 2022-02-17T20:00:05+01:00
New Revision: 1af15de6b77278fec12e72ca8be9f6408fd4761b
URL: https://github.com/llvm/llvm-project/commit/1af15de6b77278fec12e72ca8be9f6408fd4761b
DIFF: https://github.com/llvm/llvm-project/commit/1af15de6b77278fec12e72ca8be9f6408fd4761b.diff
LOG: [mlir] Switch {collapse,expand}_shape ops to the declarative assembly format
Same functionality, a lot less code.
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 2445280b01573..9102da3db7877 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1240,9 +1240,12 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
Value getViewSource() { return src(); }
}];
+ let assemblyFormat = [{
+ $src $reassociation attr-dict `:` type($src) `into` type($result)
+ }];
+
let hasFolder = 1;
let hasCanonicalizer = 1;
- let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 3a2ec73791d3c..a53f3e7e5ca35 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -732,9 +732,12 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
}
}];
+ let assemblyFormat = [{
+ $src $reassociation attr-dict `:` type($src) `into` type($result)
+ }];
+
let hasFolder = 1;
let hasCanonicalizer = 1;
- let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index b2d4cf1e4bffc..dfeac25fd6c99 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -74,31 +74,6 @@ getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType);
bool isReassociationValid(ArrayRef<AffineMap> reassociation,
int *invalidIndex = nullptr);
-/// Parse a reshape-like op, i.e. linalg::(Tensor)ExpandShapeOp,
-/// linalg::(Tensor)CollapseShapeOp.
-ParseResult parseReshapeLikeOp(OpAsmParser &parser, OperationState &result);
-
-/// Print a reshape-like op, i.e. linalg::(Tensor)ExpandShapeOp,
-/// linalg::(Tensor)CollapseShapeOp.
-template <typename ReshapeLikeOp>
-void printReshapeOp(OpAsmPrinter &p, ReshapeLikeOp op) {
- p << ' ' << op.src() << " [";
-
- llvm::interleaveComma(op.reassociation(), p, [&](const Attribute &attr) {
- p << '[';
- auto arrayAttr = attr.template cast<ArrayAttr>();
- llvm::interleaveComma(arrayAttr, p, [&](const Attribute &attr) {
- p << attr.cast<IntegerAttr>().getInt();
- });
- p << ']';
- });
-
- p << "] ";
- p.printOptionalAttrDict(op->getAttrs(),
- /*elidedAttrs=*/{getReassociationAttrName()});
- p << ": " << op.src().getType() << " into " << op.getType();
-}
-
template <typename ReshapeOpTy, typename InverseReshapeOpTy>
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
ArrayRef<Attribute> operands) {
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b64fb00ce4cc1..541da53cb2a49 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1370,21 +1370,6 @@ SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
getReassociationIndices());
}
-ParseResult ExpandShapeOp::parse(OpAsmParser &parser, OperationState &result) {
- return parseReshapeLikeOp(parser, result);
-}
-void ExpandShapeOp::print(OpAsmPrinter &p) {
- ::mlir::printReshapeOp<ExpandShapeOp>(p, *this);
-}
-
-ParseResult CollapseShapeOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parseReshapeLikeOp(parser, result);
-}
-void CollapseShapeOp::print(OpAsmPrinter &p) {
- ::mlir::printReshapeOp<CollapseShapeOp>(p, *this);
-}
-
/// Detect whether memref dims [dim, dim + extent) can be reshaped without
/// copies.
static bool isReshapableDimBand(unsigned dim, unsigned extent,
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index a13a274c28e2a..5edb620d5cc32 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -733,17 +733,6 @@ SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
getReassociationIndices());
}
-ParseResult ExpandShapeOp::parse(OpAsmParser &parser, OperationState &result) {
- return parseReshapeLikeOp(parser, result);
-}
-void ExpandShapeOp::print(OpAsmPrinter &p) { printReshapeOp(p, *this); }
-
-ParseResult CollapseShapeOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parseReshapeLikeOp(parser, result);
-}
-void CollapseShapeOp::print(OpAsmPrinter &p) { printReshapeOp(p, *this); }
-
/// Compute the RankedTensorType obtained by applying `reassociation` to `type`.
static RankedTensorType
computeTensorReshapeCollapsedType(RankedTensorType type,
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 0048abee4194a..fd509621015d2 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -91,62 +91,6 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
return reassociationMap;
}
-ParseResult mlir::parseReshapeLikeOp(OpAsmParser &parser,
- OperationState &result) {
- // Parse the operand.
- OpAsmParser::OperandType src;
- if (parser.parseOperand(src))
- return failure();
-
- // Parse reassociation indices.
- Builder &b = parser.getBuilder();
- SmallVector<Attribute, 4> reassociation;
- if (parser.parseLSquare())
- return failure();
-
- while (true) {
- if (succeeded(parser.parseOptionalRSquare()))
- break;
- if (parser.parseLSquare())
- return failure();
- SmallVector<int64_t> indices;
- while (true) {
- int64_t index;
- if (parser.parseInteger(index))
- return failure();
- indices.push_back(index);
-
- if (succeeded(parser.parseOptionalComma()))
- continue;
- if (failed(parser.parseRSquare()))
- return failure();
- break;
- }
- reassociation.push_back(b.getI64ArrayAttr(indices));
- if (succeeded(parser.parseOptionalComma()))
- continue;
- if (failed(parser.parseRSquare()))
- return failure();
- break;
- }
-
- result.addAttribute(getReassociationAttrName(),
- b.getArrayAttr(reassociation));
-
- // Parse optional attributes.
- parser.parseOptionalAttrDict(result.attributes);
-
- // Parse types.
- Type srcType;
- Type resultType;
- if (parser.parseColon() || parser.parseType(srcType) ||
- parser.resolveOperand(src, srcType, result.operands) ||
- parser.parseKeyword("into") || parser.parseType(resultType))
- return failure();
- result.addTypes(resultType);
- return success();
-}
-
Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices(
ArrayRef<ReassociationIndices> producerReassociations,
ArrayRef<ReassociationIndices> consumerReassociations,
More information about the Mlir-commits mailing list