[Mlir-commits] [mlir] 5bdb21d - [mlir] Use assemblyFormat in AllocLikeOp.

Christian Sigg llvmlistbot at llvm.org
Wed Nov 11 01:27:32 PST 2020


Author: Christian Sigg
Date: 2020-11-11T10:27:20+01:00
New Revision: 5bdb21df21c6c78554a99754ef14da06a85f9910

URL: https://github.com/llvm/llvm-project/commit/5bdb21df21c6c78554a99754ef14da06a85f9910
DIFF: https://github.com/llvm/llvm-project/commit/5bdb21df21c6c78554a99754ef14da06a85f9910.diff

LOG: [mlir] Use assemblyFormat in AllocLikeOp.

Split operands into dynamicSizes and symbolOperands.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D90589

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Transforms/Utils/Utils.cpp
    mlir/test/IR/invalid-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index d87869270c21..902220b2d9ce 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -313,16 +313,6 @@ llvm::Optional<SmallVector<bool, 4>>
 computeRankReductionMask(ArrayRef<int64_t> originalShape,
                          ArrayRef<int64_t> reducedShape);
 
-/// Prints dimension and symbol list.
-void printDimAndSymbolList(Operation::operand_iterator begin,
-                           Operation::operand_iterator end, unsigned numDims,
-                           OpAsmPrinter &p);
-
-/// Parses dimension and symbol list and returns true if parsing failed.
-ParseResult parseDimAndSymbolList(OpAsmParser &parser,
-                                  SmallVectorImpl<Value> &operands,
-                                  unsigned &numDims);
-
 /// Determines whether MemRefCastOp casts to a more dynamic version of the
 /// source memref. This is useful to to fold a memref_cast into a consuming op
 /// and implement canonicalization patterns for ops in different dialects that

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index d15f06b37fa5..652efa70fe06 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -165,20 +165,38 @@ class ComplexFloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
 class AllocLikeOp<string mnemonic,
                   Resource resource,
                   list<OpTrait> traits = []> :
-    Std_Op<mnemonic, !listconcat([MemoryEffects<[MemAlloc<resource>]>], traits)> {
-
-  let arguments = (ins Variadic<Index>:$value,
+    Std_Op<mnemonic,
+    !listconcat([
+      MemoryEffects<[MemAlloc<resource>]>,
+      AttrSizedOperandSegments
+    ], traits)> {
+
+  let arguments = (ins Variadic<Index>:$dynamicSizes,
+                   // The symbolic operands (the ones in square brackets) bind
+                   // to the symbols of the memref's layout map.
+                   Variadic<Index>:$symbolOperands,
                    Confined<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$alignment);
-  let results = (outs Res<AnyMemRef, "", [MemAlloc<resource>]>);
+  let results = (outs Res<AnyMemRef, "", [MemAlloc<resource>]>:$memref);
 
   let builders = [
-    OpBuilderDAG<(ins "MemRefType":$memrefType), [{
-      $_state.types.push_back(memrefType);
+    OpBuilderDAG<(ins "MemRefType":$memrefType,
+                  CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, memrefType, {}, alignment);
     }]>,
-    OpBuilderDAG<(ins "MemRefType":$memrefType, "ValueRange":$operands,
-      CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
-      $_state.addOperands(operands);
+    OpBuilderDAG<(ins "MemRefType":$memrefType, "ValueRange":$dynamicSizes,
+                  CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, memrefType, dynamicSizes, {}, alignment);
+    }]>,
+    OpBuilderDAG<(ins "MemRefType":$memrefType, "ValueRange":$dynamicSizes,
+                  "ValueRange":$symbolOperands,
+                  CArg<"IntegerAttr", "{}">:$alignment), [{
       $_state.types.push_back(memrefType);
+      $_state.addOperands(dynamicSizes);
+      $_state.addOperands(symbolOperands);
+      $_state.addAttribute(getOperandSegmentSizeAttr(),
+          $_builder.getI32VectorAttr({
+              static_cast<int32_t>(dynamicSizes.size()),
+              static_cast<int32_t>(symbolOperands.size())}));
       if (alignment)
         $_state.addAttribute(getAlignmentAttrName(), alignment);
     }]>];
@@ -188,23 +206,13 @@ class AllocLikeOp<string mnemonic,
 
     MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
 
-    /// Returns the number of symbolic operands (the ones in square brackets),
-    /// which bind to the symbols of the memref's layout map.
-    unsigned getNumSymbolicOperands() {
-      return getNumOperands() - getType().getNumDynamicDims();
-    }
-
-    /// Returns the symbolic operands (the ones in square brackets), which bind
-    /// to the symbols of the memref's layout map.
-    operand_range getSymbolicOperands() {
-      return {operand_begin() + getType().getNumDynamicDims(), operand_end()};
-    }
-
     /// Returns the dynamic sizes for this alloc operation if specified.
-    operand_range getDynamicSizes() { return getOperands(); }
+    operand_range getDynamicSizes() { return dynamicSizes(); }
   }];
 
-  let parser = [{ return ::parseAllocLikeOp(parser, result); }];
+  let assemblyFormat = [{
+    `(`$dynamicSizes`)` (`` `[` $symbolOperands^ `]`)? attr-dict `:` type($memref)
+  }];
 
   let hasCanonicalizer = 1;
 }

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 4cb6821c9c15..4c2196ff176f 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -297,6 +297,33 @@ static bool isValidAffineIndexOperand(Value value, Region *region) {
   return isValidDim(value, region) || isValidSymbol(value, region);
 }
 
+/// Prints dimension and symbol list.
+static void printDimAndSymbolList(Operation::operand_iterator begin,
+                                  Operation::operand_iterator end,
+                                  unsigned numDims, OpAsmPrinter &printer) {
+  OperandRange operands(begin, end);
+  printer << '(' << operands.take_front(numDims) << ')';
+  if (operands.size() > numDims)
+    printer << '[' << operands.drop_front(numDims) << ']';
+}
+
+/// Parses dimension and symbol list and returns true if parsing failed.
+static ParseResult parseDimAndSymbolList(OpAsmParser &parser,
+                                         SmallVectorImpl<Value> &operands,
+                                         unsigned &numDims) {
+  SmallVector<OpAsmParser::OperandType, 8> opInfos;
+  if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
+    return failure();
+  // Store number of dimensions for validation by caller.
+  numDims = opInfos.size();
+
+  // Parse the optional symbol operands.
+  auto indexTy = parser.getBuilder().getIndexType();
+  return failure(parser.parseOperandList(
+                     opInfos, OpAsmParser::Delimiter::OptionalSquare) ||
+                 parser.resolveOperands(opInfos, indexTy, operands));
+}
+
 /// Utility function to verify that a set of operands are valid dimension and
 /// symbol identifiers. The operands should be laid out such that the dimension
 /// operands are before the symbol operands. This function returns failure if

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index d333ddc8e34c..bc584bd628fb 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -169,36 +169,6 @@ Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder,
   return builder.create<ConstantOp>(loc, type, value);
 }
 
-void mlir::printDimAndSymbolList(Operation::operand_iterator begin,
-                                 Operation::operand_iterator end,
-                                 unsigned numDims, OpAsmPrinter &p) {
-  Operation::operand_range operands(begin, end);
-  p << '(' << operands.take_front(numDims) << ')';
-  if (operands.size() != numDims)
-    p << '[' << operands.drop_front(numDims) << ']';
-}
-
-// Parses dimension and symbol list, and sets 'numDims' to the number of
-// dimension operands parsed.
-// Returns 'false' on success and 'true' on error.
-ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser,
-                                        SmallVectorImpl<Value> &operands,
-                                        unsigned &numDims) {
-  SmallVector<OpAsmParser::OperandType, 8> opInfos;
-  if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
-    return failure();
-  // Store number of dimensions for validation by caller.
-  numDims = opInfos.size();
-
-  // Parse the optional symbol operands.
-  auto indexTy = parser.getBuilder().getIndexType();
-  if (parser.parseOperandList(opInfos,
-                              OpAsmParser::Delimiter::OptionalSquare) ||
-      parser.resolveOperands(opInfos, indexTy, operands))
-    return failure();
-  return success();
-}
-
 /// Matches a ConstantIndexOp.
 /// TODO: This should probably just be a general matcher that uses m_Constant
 /// and checks the operation for an index type.
@@ -404,90 +374,37 @@ static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) {
 //===----------------------------------------------------------------------===//
 
 template <typename AllocLikeOp>
-static void printAllocLikeOp(OpAsmPrinter &p, AllocLikeOp op, StringRef name) {
-  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
-                "applies to only alloc or alloca");
-  p << name;
-
-  // Print dynamic dimension operands.
-  MemRefType type = op.getType();
-  printDimAndSymbolList(op.operand_begin(), op.operand_end(),
-                        type.getNumDynamicDims(), p);
-  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
-  p << " : " << type;
-}
-
-static void print(OpAsmPrinter &p, AllocOp op) {
-  printAllocLikeOp(p, op, "alloc");
-}
-
-static void print(OpAsmPrinter &p, AllocaOp op) {
-  printAllocLikeOp(p, op, "alloca");
-}
-
-static ParseResult parseAllocLikeOp(OpAsmParser &parser,
-                                    OperationState &result) {
-  MemRefType type;
-
-  // Parse the dimension operands and optional symbol operands, followed by a
-  // memref type.
-  unsigned numDimOperands;
-  if (parseDimAndSymbolList(parser, result.operands, numDimOperands) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(type))
-    return failure();
-
-  // Check numDynamicDims against number of question marks in memref type.
-  // Note: this check remains here (instead of in verify()), because the
-  // partition between dim operands and symbol operands is lost after parsing.
-  // Verification still checks that the total number of operands matches
-  // the number of symbols in the affine map, plus the number of dynamic
-  // dimensions in the memref.
-  if (numDimOperands != type.getNumDynamicDims())
-    return parser.emitError(parser.getNameLoc())
-           << "dimension operand count does not equal memref dynamic dimension "
-              "count";
-  result.types.push_back(type);
-  return success();
-}
-
-template <typename AllocLikeOp>
-static LogicalResult verify(AllocLikeOp op) {
+static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
   static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                 "applies to only alloc or alloca");
   auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
   if (!memRefType)
     return op.emitOpError("result must be a memref");
 
-  unsigned numSymbols = 0;
-  if (!memRefType.getAffineMaps().empty()) {
-    // Store number of symbols used in affine map (used in subsequent check).
-    AffineMap affineMap = memRefType.getAffineMaps()[0];
-    numSymbols = affineMap.getNumSymbols();
-  }
+  if (static_cast<int64_t>(op.dynamicSizes().size()) !=
+      memRefType.getNumDynamicDims())
+    return op.emitOpError("dimension operand count does not equal memref "
+                          "dynamic dimension count");
 
-  // Check that the total number of operands matches the number of symbols in
-  // the affine map, plus the number of dynamic dimensions specified in the
-  // memref type.
-  unsigned numDynamicDims = memRefType.getNumDynamicDims();
-  if (op.getNumOperands() != numDynamicDims + numSymbols)
+  unsigned numSymbols = 0;
+  if (!memRefType.getAffineMaps().empty())
+    numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
+  if (op.symbolOperands().size() != numSymbols)
     return op.emitOpError(
-        "operand count does not equal dimension plus symbol operand count");
+        "symbol operand count does not equal memref symbol count");
 
-  // Verify that all operands are of type Index.
-  for (auto operandType : op.getOperandTypes())
-    if (!operandType.isIndex())
-      return op.emitOpError("requires operands to be of type Index");
+  return success();
+}
 
-  if (std::is_same<AllocLikeOp, AllocOp>::value)
-    return success();
+static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }
 
+static LogicalResult verify(AllocaOp op) {
   // An alloca op needs to have an ancestor with an allocation scope trait.
-  if (!op.template getParentWithTrait<OpTrait::AutomaticAllocationScope>())
+  if (!op.getParentWithTrait<OpTrait::AutomaticAllocationScope>())
     return op.emitOpError(
         "requires an ancestor op with AutomaticAllocationScope trait");
 
-  return success();
+  return verifyAllocLikeOp(op);
 }
 
 namespace {

diff  --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index 516f8c060a93..cef0a827f08d 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -401,7 +401,7 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) {
   // Fetch a new memref type after normalizing the old memref to have an
   // identity map layout.
   MemRefType newMemRefType =
-      normalizeMemRefType(memrefType, b, allocOp.getNumSymbolicOperands());
+      normalizeMemRefType(memrefType, b, allocOp.symbolOperands().size());
   if (newMemRefType == memrefType)
     // Either memrefType already had an identity map or the map couldn't be
     // transformed to an identity map.
@@ -409,9 +409,9 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) {
 
   Value oldMemRef = allocOp.getResult();
 
-  SmallVector<Value, 4> symbolOperands(allocOp.getSymbolicOperands());
+  SmallVector<Value, 4> symbolOperands(allocOp.symbolOperands());
   AllocOp newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType,
-                                       llvm::None, allocOp.alignmentAttr());
+                                       allocOp.alignmentAttr());
   AffineMap layoutMap = memrefType.getAffineMaps().front();
   // Replace all uses of the old memref.
   if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 76aff5c6d401..eb2477438649 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -87,7 +87,8 @@ func @bad_alloc_wrong_dynamic_dim_count() {
 ^bb0:
   %0 = constant 7 : index
   // Test alloc with wrong number of dynamic dimensions.
-  %1 = alloc(%0)[%1] : memref<2x4xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> // expected-error {{op 'std.alloc' dimension operand count does not equal memref dynamic dimension count}}
+  // expected-error @+1 {{dimension operand count does not equal memref dynamic dimension count}}
+  %1 = alloc(%0)[%0] : memref<2x4xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1>
   return
 }
 
@@ -97,7 +98,8 @@ func @bad_alloc_wrong_symbol_count() {
 ^bb0:
   %0 = constant 7 : index
   // Test alloc with wrong number of symbols
-  %1 = alloc(%0) : memref<2x?xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> // expected-error {{operand count does not equal dimension plus symbol operand count}}
+  // expected-error @+1 {{symbol operand count does not equal memref symbol count}}
+  %1 = alloc(%0) : memref<2x?xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1>
   return
 }
 


        


More information about the Mlir-commits mailing list