[Mlir-commits] [mlir] 1af15de - [mlir] Switch {collapse,expand}_shape ops to the declarative assembly format

Benjamin Kramer llvmlistbot at llvm.org
Thu Feb 17 11:04:42 PST 2022


Author: Benjamin Kramer
Date: 2022-02-17T20:00:05+01:00
New Revision: 1af15de6b77278fec12e72ca8be9f6408fd4761b

URL: https://github.com/llvm/llvm-project/commit/1af15de6b77278fec12e72ca8be9f6408fd4761b
DIFF: https://github.com/llvm/llvm-project/commit/1af15de6b77278fec12e72ca8be9f6408fd4761b.diff

LOG: [mlir] Switch {collapse,expand}_shape ops to the declarative assembly format

Same functionality, a lot less code.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 2445280b01573..9102da3db7877 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1240,9 +1240,12 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     Value getViewSource() { return src(); }
   }];
 
+  let assemblyFormat = [{
+    $src $reassociation attr-dict `:` type($src) `into` type($result)
+  }];
+
   let hasFolder = 1;
   let hasCanonicalizer = 1;
-  let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 }
 

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 3a2ec73791d3c..a53f3e7e5ca35 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -732,9 +732,12 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     }
   }];
 
+  let assemblyFormat = [{
+    $src $reassociation attr-dict `:` type($src) `into` type($result)
+  }];
+
   let hasFolder = 1;
   let hasCanonicalizer = 1;
-  let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 }
 

diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index b2d4cf1e4bffc..dfeac25fd6c99 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -74,31 +74,6 @@ getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType);
 bool isReassociationValid(ArrayRef<AffineMap> reassociation,
                           int *invalidIndex = nullptr);
 
-/// Parse a reshape-like op, i.e. linalg::(Tensor)ExpandShapeOp,
-/// linalg::(Tensor)CollapseShapeOp.
-ParseResult parseReshapeLikeOp(OpAsmParser &parser, OperationState &result);
-
-/// Print a reshape-like op, i.e. linalg::(Tensor)ExpandShapeOp,
-/// linalg::(Tensor)CollapseShapeOp.
-template <typename ReshapeLikeOp>
-void printReshapeOp(OpAsmPrinter &p, ReshapeLikeOp op) {
-  p << ' ' << op.src() << " [";
-
-  llvm::interleaveComma(op.reassociation(), p, [&](const Attribute &attr) {
-    p << '[';
-    auto arrayAttr = attr.template cast<ArrayAttr>();
-    llvm::interleaveComma(arrayAttr, p, [&](const Attribute &attr) {
-      p << attr.cast<IntegerAttr>().getInt();
-    });
-    p << ']';
-  });
-
-  p << "] ";
-  p.printOptionalAttrDict(op->getAttrs(),
-                          /*elidedAttrs=*/{getReassociationAttrName()});
-  p << ": " << op.src().getType() << " into " << op.getType();
-}
-
 template <typename ReshapeOpTy, typename InverseReshapeOpTy>
 static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
                                   ArrayRef<Attribute> operands) {

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b64fb00ce4cc1..541da53cb2a49 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1370,21 +1370,6 @@ SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
                                             getReassociationIndices());
 }
 
-ParseResult ExpandShapeOp::parse(OpAsmParser &parser, OperationState &result) {
-  return parseReshapeLikeOp(parser, result);
-}
-void ExpandShapeOp::print(OpAsmPrinter &p) {
-  ::mlir::printReshapeOp<ExpandShapeOp>(p, *this);
-}
-
-ParseResult CollapseShapeOp::parse(OpAsmParser &parser,
-                                   OperationState &result) {
-  return parseReshapeLikeOp(parser, result);
-}
-void CollapseShapeOp::print(OpAsmPrinter &p) {
-  ::mlir::printReshapeOp<CollapseShapeOp>(p, *this);
-}
-
 /// Detect whether memref dims [dim, dim + extent) can be reshaped without
 /// copies.
 static bool isReshapableDimBand(unsigned dim, unsigned extent,

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index a13a274c28e2a..5edb620d5cc32 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -733,17 +733,6 @@ SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
                                             getReassociationIndices());
 }
 
-ParseResult ExpandShapeOp::parse(OpAsmParser &parser, OperationState &result) {
-  return parseReshapeLikeOp(parser, result);
-}
-void ExpandShapeOp::print(OpAsmPrinter &p) { printReshapeOp(p, *this); }
-
-ParseResult CollapseShapeOp::parse(OpAsmParser &parser,
-                                   OperationState &result) {
-  return parseReshapeLikeOp(parser, result);
-}
-void CollapseShapeOp::print(OpAsmPrinter &p) { printReshapeOp(p, *this); }
-
 /// Compute the RankedTensorType obtained by applying `reassociation` to `type`.
 static RankedTensorType
 computeTensorReshapeCollapsedType(RankedTensorType type,

diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 0048abee4194a..fd509621015d2 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -91,62 +91,6 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
   return reassociationMap;
 }
 
-ParseResult mlir::parseReshapeLikeOp(OpAsmParser &parser,
-                                     OperationState &result) {
-  // Parse the operand.
-  OpAsmParser::OperandType src;
-  if (parser.parseOperand(src))
-    return failure();
-
-  // Parse reassociation indices.
-  Builder &b = parser.getBuilder();
-  SmallVector<Attribute, 4> reassociation;
-  if (parser.parseLSquare())
-    return failure();
-
-  while (true) {
-    if (succeeded(parser.parseOptionalRSquare()))
-      break;
-    if (parser.parseLSquare())
-      return failure();
-    SmallVector<int64_t> indices;
-    while (true) {
-      int64_t index;
-      if (parser.parseInteger(index))
-        return failure();
-      indices.push_back(index);
-
-      if (succeeded(parser.parseOptionalComma()))
-        continue;
-      if (failed(parser.parseRSquare()))
-        return failure();
-      break;
-    }
-    reassociation.push_back(b.getI64ArrayAttr(indices));
-    if (succeeded(parser.parseOptionalComma()))
-      continue;
-    if (failed(parser.parseRSquare()))
-      return failure();
-    break;
-  }
-
-  result.addAttribute(getReassociationAttrName(),
-                      b.getArrayAttr(reassociation));
-
-  // Parse optional attributes.
-  parser.parseOptionalAttrDict(result.attributes);
-
-  // Parse types.
-  Type srcType;
-  Type resultType;
-  if (parser.parseColon() || parser.parseType(srcType) ||
-      parser.resolveOperand(src, srcType, result.operands) ||
-      parser.parseKeyword("into") || parser.parseType(resultType))
-    return failure();
-  result.addTypes(resultType);
-  return success();
-}
-
 Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices(
     ArrayRef<ReassociationIndices> producerReassociations,
     ArrayRef<ReassociationIndices> consumerReassociations,


        


More information about the Mlir-commits mailing list