[Mlir-commits] [mlir] c247081 - [mlir] NFC - Refactor and expose a printOffsetsSizesAndStrides helper function.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Nov 24 12:02:14 PST 2020


Author: Nicolas Vasilache
Date: 2020-11-24T20:00:59Z
New Revision: c24708102501115efae27f82c24d5991059a5770

URL: https://github.com/llvm/llvm-project/commit/c24708102501115efae27f82c24d5991059a5770
DIFF: https://github.com/llvm/llvm-project/commit/c24708102501115efae27f82c24d5991059a5770.diff

LOG: [mlir] NFC - Refactor and expose a printOffsetsSizesAndStrides helper function.

Print part of an op of the form:
```
  <optional-offset-prefix>`[` offset-list `]`
  <optional-size-prefix>`[` size-list `]`
  <optional-stride-prefix>`[` stride-list `]`
```

Also address some leftover nits.

Differential revision: https://reviews.llvm.org/D92031

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/Interfaces/ViewLikeInterface.h
    mlir/include/mlir/Interfaces/ViewLikeInterface.td
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Interfaces/ViewLikeInterface.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 7a775c3a317b..c44d99b1620d 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -232,15 +232,6 @@ class BaseOpWithOffsetSizesAndStrides<string mnemonic, list<OpTrait> traits = []
     SmallVector<Range, 8> getOrCreateRanges(OpBuilder &b, Location loc) {
       return mlir::getOrCreateRanges(*this, b, loc);
     }
-
-    static ArrayRef<StringRef> getSpecialAttrNames() {
-      static SmallVector<StringRef, 4> names{
-        OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(),
-        OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(),
-        OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(),
-        getOperandSegmentSizeAttr()};
-      return names;
-   }
   }];
 }
 

diff  --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 15ba5d18a6d6..b7d796f39f4d 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -29,6 +29,24 @@ struct Range {
 
 class OffsetSizeAndStrideOpInterface;
 LogicalResult verify(OffsetSizeAndStrideOpInterface op);
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/ViewLikeInterface.h.inc"
+
+namespace mlir {
+/// Print part of an op of the form:
+/// ```
+///   <optional-offset-prefix>`[` offset-list `]`
+///   <optional-size-prefix>`[` size-list `]`
+///   <optional-stride-prefix>[` stride-list `]`
+/// ```
+void printOffsetsSizesAndStrides(
+    OpAsmPrinter &p, OffsetSizeAndStrideOpInterface op,
+    StringRef offsetPrefix = "", StringRef sizePrefix = " ",
+    StringRef stridePrefix = " ",
+    ArrayRef<StringRef> elidedAttrs =
+        OffsetSizeAndStrideOpInterface::getSpecialAttrNames());
 
 /// Parse trailing part of an op of the form:
 /// ```
@@ -59,10 +77,16 @@ ParseResult parseOffsetsSizesAndStrides(
         nullptr,
     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix =
         nullptr);
+/// `preResolutionFn`-less version of `parseOffsetsSizesAndStrides`.
+ParseResult parseOffsetsSizesAndStrides(
+    OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
+    llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix =
+        nullptr,
+    llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix =
+        nullptr,
+    llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix =
+        nullptr);
 
 } // namespace mlir
 
-/// Include the generated interface declarations.
-#include "mlir/Interfaces/ViewLikeInterface.h.inc"
-
 #endif // MLIR_INTERFACES_VIEWLIKEINTERFACE_H_

diff  --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
index d3a7bf185d13..31f9bca8d7fb 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -357,6 +357,14 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
     static StringRef getStaticStridesAttrName() {
       return "static_strides";
     }
+    static ArrayRef<StringRef> getSpecialAttrNames() {
+      static SmallVector<StringRef, 4> names{
+        OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(),
+        OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(),
+        OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(),
+        OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr()};
+      return names;
+   }
   }];
 
   let verify = [{

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 3160e8f8be0b..6ed17c744f8c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -793,12 +793,6 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
   return fusableDependences;
 }
 
-static bool isZero(Value v) {
-  if (auto cst = v.getDefiningOp<ConstantIndexOp>())
-    return cst.getValue() == 0;
-  return false;
-}
-
 /// Tile the fused loops in the root operation, by setting the tile sizes for
 /// all other loops to zero (those will be tiled later).
 static Optional<TiledLinalgOp> tileRootOperation(

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 1fe52b70992e..1437552f6e2a 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -248,49 +248,6 @@ OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) {
                                         [](APInt a, APInt b) { return a + b; });
 }
 
-//===----------------------------------------------------------------------===//
-// BaseOpWithOffsetSizesAndStridesOp
-//===----------------------------------------------------------------------===//
-
-/// Print a list with either (1) the static integer value in `arrayAttr` if
-/// `isDynamic` evaluates to false or (2) the next value otherwise.
-/// This allows idiomatic printing of mixed value and integer attributes in a
-/// list. E.g. `[%arg0, 7, 42, %arg42]`.
-static void
-printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values,
-                              ArrayAttr arrayAttr,
-                              llvm::function_ref<bool(int64_t)> isDynamic) {
-  p << '[';
-  unsigned idx = 0;
-  llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
-    int64_t val = a.cast<IntegerAttr>().getInt();
-    if (isDynamic(val))
-      p << values[idx++];
-    else
-      p << val;
-  });
-  p << ']';
-}
-
-/// Verify that a particular offset/size/stride static attribute is well-formed.
-static LogicalResult verifyOpWithOffsetSizesAndStridesPart(
-    OffsetSizeAndStrideOpInterface op, StringRef name,
-    unsigned expectedNumElements, StringRef attrName, ArrayAttr attr,
-    llvm::function_ref<bool(int64_t)> isDynamic, ValueRange values) {
-  /// Check static and dynamic offsets/sizes/strides breakdown.
-  if (attr.size() != expectedNumElements)
-    return op.emitError("expected ")
-           << expectedNumElements << " " << name << " values";
-  unsigned expectedNumDynamicEntries =
-      llvm::count_if(attr.getValue(), [&](Attribute attr) {
-        return isDynamic(attr.cast<IntegerAttr>().getInt());
-      });
-  if (values.size() != expectedNumDynamicEntries)
-    return op.emitError("expected ")
-           << expectedNumDynamicEntries << " dynamic " << name << " values";
-  return success();
-}
-
 /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
 static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
   return llvm::to_vector<4>(
@@ -2390,9 +2347,9 @@ void mlir::MemRefReinterpretCastOp::build(OpBuilder &b, OperationState &result,
         staticStridesVector, offset, sizes, strides, attrs);
 }
 
-/// Print of the form:
+/// Print a memref_reinterpret_cast op of the form:
 /// ```
-///   `name` ssa-name to
+///   `memref_reinterpret_cast` ssa-name to
 ///       offset: `[` offset `]`
 ///       sizes: `[` size-list `]`
 ///       strides:`[` stride-list `]`
@@ -2400,19 +2357,11 @@ void mlir::MemRefReinterpretCastOp::build(OpBuilder &b, OperationState &result,
 /// ```
 static void print(OpAsmPrinter &p, MemRefReinterpretCastOp op) {
   int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
-  p << op.getOperationName().drop_front(stdDotLen) << " " << op.source()
-    << " to offset: ";
-  printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
-                                ShapedType::isDynamicStrideOrOffset);
-  p << ", sizes: ";
-  printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
-                                ShapedType::isDynamic);
-  p << ", strides: ";
-  printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
-                                ShapedType::isDynamicStrideOrOffset);
-  p.printOptionalAttrDict(
-      op.getAttrs(),
-      /*elidedAttrs=*/MemRefReinterpretCastOp::getSpecialAttrNames());
+  p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' ';
+  p << op.source() << " ";
+  printOffsetsSizesAndStrides(
+      p, op, /*offsetPrefix=*/"to offset: ", /*sizePrefix=*/", sizes: ",
+      /*stridePrefix=*/", strides: ");
   p << ": " << op.source().getType() << " to " << op.getType();
 }
 
@@ -2451,8 +2400,8 @@ static ParseResult parseMemRefReinterpretCastOp(OpAsmParser &parser,
                    parser.parseKeywordType("to", dstType) ||
                    parser.resolveOperand(srcInfo, srcType, result.operands));
   };
-  SmallVector<int, 4> segmentSizes{1}; // source memref
-  if (failed(parseOffsetsSizesAndStrides(parser, result, segmentSizes,
+  if (failed(parseOffsetsSizesAndStrides(parser, result,
+                                         /*segmentSizes=*/{1}, // source memref
                                          preResolutionFn, parseOffsetPrefix,
                                          parseSizePrefix, parseStridePrefix)))
     return failure();
@@ -3122,38 +3071,18 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
       sourceMemRefType.getMemorySpace());
 }
 
-/// Print SubViewOp in the form:
+/// Print a subview op of the form:
 /// ```
-///   subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
+///   `subview` ssa-name
+///     `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
 ///     `:` strided-memref-type `to` strided-memref-type
 /// ```
-template <typename OpType>
-static void printOpWithOffsetsSizesAndStrides(
-    OpAsmPrinter &p, OpType op,
-    llvm::function_ref<void(OpAsmPrinter &p, OpType op)> printExtraOperands =
-        [](OpAsmPrinter &p, OpType op) {},
-    StringRef resultTypeKeyword = "to") {
+static void print(OpAsmPrinter &p, SubViewOp op) {
   int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
   p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' ';
   p << op.source();
-  printExtraOperands(p, op);
-  printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
-                                ShapedType::isDynamicStrideOrOffset);
-  p << ' ';
-  printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
-                                ShapedType::isDynamic);
-  p << ' ';
-  printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
-                                ShapedType::isDynamicStrideOrOffset);
-  p << ' ';
-  p.printOptionalAttrDict(op.getAttrs(),
-                          /*elidedAttrs=*/{OpType::getSpecialAttrNames()});
-  p << " : " << op.getSourceType() << " " << resultTypeKeyword << " "
-    << op.getType();
-}
-
-static void print(OpAsmPrinter &p, SubViewOp op) {
-  return printOpWithOffsetsSizesAndStrides<SubViewOp>(p, op);
+  printOffsetsSizesAndStrides(p, op);
+  p << " : " << op.getSourceType() << " to " << op.getType();
 }
 
 /// Parse a subview op of the form:
@@ -3173,8 +3102,9 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
                    parser.parseKeywordType("to", dstType) ||
                    parser.resolveOperand(srcInfo, srcType, result.operands));
   };
-  SmallVector<int, 4> segmentSizes{1}; // source memref
-  if (failed(parseOffsetsSizesAndStrides(parser, result, segmentSizes,
+
+  if (failed(parseOffsetsSizesAndStrides(parser, result,
+                                         /*segmentSizes=*/{1}, // source memref
                                          preResolutionFn)))
     return failure();
   return parser.addTypeToList(dstType, result.types);
@@ -3750,8 +3680,18 @@ OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
 // SubTensorOp
 //===----------------------------------------------------------------------===//
 
+/// Print a subtensor op of the form:
+/// ```
+///   `subtensor` ssa-name
+///     `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
+///     `:` ranked-tensor-type `to` ranked-tensor-type
+/// ```
 static void print(OpAsmPrinter &p, SubTensorOp op) {
-  return printOpWithOffsetsSizesAndStrides<SubTensorOp>(p, op);
+  int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
+  p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' ';
+  p << op.source();
+  printOffsetsSizesAndStrides(p, op);
+  p << " : " << op.getSourceType() << " to " << op.getType();
 }
 
 /// Parse a subtensor op of the form:
@@ -3772,8 +3712,9 @@ static ParseResult parseSubTensorOp(OpAsmParser &parser,
                    parser.parseKeywordType("to", dstType) ||
                    parser.resolveOperand(srcInfo, srcType, result.operands));
   };
-  SmallVector<int, 4> segmentSizes{1}; // source tensor
-  if (failed(parseOffsetsSizesAndStrides(parser, result, segmentSizes,
+
+  if (failed(parseOffsetsSizesAndStrides(parser, result,
+                                         /*segmentSizes=*/{1}, // source tensor
                                          preResolutionFn)))
     return failure();
   return parser.addTypeToList(dstType, result.types);
@@ -3853,11 +3794,18 @@ void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
 // SubTensorInsertOp
 //===----------------------------------------------------------------------===//
 
+/// Print a subtensor_insert op of the form:
+/// ```
+///   `subtensor_insert` ssa-name `into` ssa-name
+///     `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
+///     `:` ranked-tensor-type `into` ranked-tensor-type
+/// ```
 static void print(OpAsmPrinter &p, SubTensorInsertOp op) {
-  return printOpWithOffsetsSizesAndStrides<SubTensorInsertOp>(
-      p, op,
-      [](OpAsmPrinter &p, SubTensorInsertOp op) { p << " into " << op.dest(); },
-      /*resultTypeKeyword=*/"into");
+  int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
+  p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' ';
+  p << op.source() << " into " << op.dest();
+  printOffsetsSizesAndStrides(p, op);
+  p << " : " << op.getSourceType() << " into " << op.getType();
 }
 
 /// Parse a subtensor_insert op of the form:
@@ -3880,9 +3828,11 @@ static ParseResult parseSubTensorInsertOp(OpAsmParser &parser,
                    parser.resolveOperand(srcInfo, srcType, result.operands) ||
                    parser.resolveOperand(dstInfo, dstType, result.operands));
   };
-  SmallVector<int, 4> segmentSizes{1, 1}; // source tensor, destination tensor
-  if (failed(parseOffsetsSizesAndStrides(parser, result, segmentSizes,
-                                         preResolutionFn)))
+
+  if (failed(parseOffsetsSizesAndStrides(
+          parser, result,
+          /*segmentSizes=*/{1, 1}, // source tensor, destination tensor
+          preResolutionFn)))
     return failure();
   return parser.addTypeToList(dstType, result.types);
 }

diff  --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index d8a540fa72ff..6127d08a8fc5 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -57,6 +57,44 @@ LogicalResult mlir::verify(OffsetSizeAndStrideOpInterface op) {
   return success();
 }
 
+/// Print a list with either (1) the static integer value in `arrayAttr` if
+/// `isDynamic` evaluates to false or (2) the next value otherwise.
+/// This allows idiomatic printing of mixed value and integer attributes in a
+/// list. E.g. `[%arg0, 7, 42, %arg42]`.
+static void
+printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values,
+                              ArrayAttr arrayAttr,
+                              llvm::function_ref<bool(int64_t)> isDynamic) {
+  p << '[';
+  unsigned idx = 0;
+  llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
+    int64_t val = a.cast<IntegerAttr>().getInt();
+    if (isDynamic(val))
+      p << values[idx++];
+    else
+      p << val;
+  });
+  p << ']';
+}
+
+void mlir::printOffsetsSizesAndStrides(OpAsmPrinter &p,
+                                       OffsetSizeAndStrideOpInterface op,
+                                       StringRef offsetPrefix,
+                                       StringRef sizePrefix,
+                                       StringRef stridePrefix,
+                                       ArrayRef<StringRef> elidedAttrs) {
+  p << offsetPrefix;
+  printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
+                                ShapedType::isDynamicStrideOrOffset);
+  p << sizePrefix;
+  printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
+                                ShapedType::isDynamic);
+  p << stridePrefix;
+  printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
+                                ShapedType::isDynamicStrideOrOffset);
+  p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
+}
+
 /// Parse a mixed list with either (1) static integer values or (2) SSA values.
 /// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
 /// encode the position of SSA values. Add the parsed SSA values to `ssa`
@@ -105,9 +143,17 @@ parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
 }
 
 ParseResult mlir::parseOffsetsSizesAndStrides(
-    OpAsmParser &parser,
-    OperationState &result,
-    ArrayRef<int> segmentSizes,
+    OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
+    llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix,
+    llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix,
+    llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix) {
+  return parseOffsetsSizesAndStrides(
+      parser, result, segmentSizes, nullptr, parseOptionalOffsetPrefix,
+      parseOptionalSizePrefix, parseOptionalStridePrefix);
+}
+
+ParseResult mlir::parseOffsetsSizesAndStrides(
+    OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
     llvm::function_ref<ParseResult(OpAsmParser &, OperationState &)>
         preResolutionFn,
     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix,
@@ -132,14 +178,14 @@ ParseResult mlir::parseOffsetsSizesAndStrides(
           ShapedType::kDynamicStrideOrOffset, stridesInfo))
     return failure();
   // Add segment sizes to result
-  SmallVector<int, 4> segmentSizesFinal(segmentSizes.begin(), segmentSizes.end());
+  SmallVector<int, 4> segmentSizesFinal(segmentSizes.begin(),
+                                        segmentSizes.end());
   segmentSizesFinal.append({static_cast<int>(offsetsInfo.size()),
-                                      static_cast<int>(sizesInfo.size()),
-                                      static_cast<int>(stridesInfo.size())});
-  auto b = parser.getBuilder();
+                            static_cast<int>(sizesInfo.size()),
+                            static_cast<int>(stridesInfo.size())});
   result.addAttribute(
       OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(),
-      b.getI32VectorAttr(segmentSizesFinal));
+      parser.getBuilder().getI32VectorAttr(segmentSizesFinal));
   return failure(
       (preResolutionFn && preResolutionFn(parser, result)) ||
       parser.resolveOperands(offsetsInfo, indexType, result.operands) ||


        


More information about the Mlir-commits mailing list