[Mlir-commits] [mlir] c21e88c - [mlir][Tensor] Avoid dropping attributes for `tensor.pad` operations during canonicalization.

Mahesh Ravishankar llvmlistbot at llvm.org
Mon Mar 20 14:04:00 PDT 2023


Author: Mahesh Ravishankar
Date: 2023-03-20T21:03:46Z
New Revision: c21e88cc02617e0f04807a8dcf164b405d67d5e4

URL: https://github.com/llvm/llvm-project/commit/c21e88cc02617e0f04807a8dcf164b405d67d5e4
DIFF: https://github.com/llvm/llvm-project/commit/c21e88cc02617e0f04807a8dcf164b405d67d5e4.diff

LOG: [mlir][Tensor] Avoid dropping attributes for `tensor.pad` operations during canonicalization.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D146440

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index cc8bbd570ef66..3c3fa70e161f7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "llvm/ADT/StringSet.h"
 #include <optional>
 
@@ -461,18 +462,12 @@ struct GenerateLoopNest {
 /// Returns an attribute list that excludes pre-defined attributes.
 template <typename OpTy>
 SmallVector<NamedAttribute> getPrunedAttributeList(OpTy op) {
-  llvm::StringSet<> elidedAttrs;
-  elidedAttrs.insert(op.getAttributeNames().begin(),
-                     op.getAttributeNames().end());
+  auto elidedAttrs = llvm::to_vector(op.getAttributeNames());
   if (isa<linalg::LinalgOp>(op.getOperation()))
-    elidedAttrs.insert(LinalgDialect::kMemoizedIndexingMapsAttrName);
-  SmallVector<NamedAttribute> attrs;
-  for (auto attr : op->getAttrs()) {
-    if (elidedAttrs.count(attr.getName()))
-      continue;
-    attrs.push_back(attr);
-  }
-  return attrs;
+    elidedAttrs.push_back(LinalgDialect::kMemoizedIndexingMapsAttrName);
+  // Qualify the call: unqualified lookup stops at this linalg template, so
+  // the mlir:: overload would otherwise only be reachable through ADL.
+  return mlir::getPrunedAttributeList(op, elidedAttrs);
 }
 
 } // namespace linalg

diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 09b7775dcaae4..66d6dcc7b27ed 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1295,13 +1295,13 @@ def Tensor_PadOp : Tensor_Op<"pad", [
 
   let builders = [
     // Build a PadOp with mixed static and dynamic entries.
-    OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$staticLow,
-      "ArrayRef<int64_t>":$staticHigh, "ValueRange":$low, "ValueRange":$high,
-      CArg<"bool", "false">:$nofold,
+    OpBuilder<(ins "Type":$resultType, "Value":$source,
+      "ArrayRef<int64_t>":$staticLow, "ArrayRef<int64_t>":$staticHigh,
+      "ValueRange":$low, "ValueRange":$high, CArg<"bool", "false">:$nofold,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
     // Build a PadOp with all dynamic entries.
-    OpBuilder<(ins "Value":$source, "ValueRange":$low, "ValueRange":$high,
-      CArg<"bool", "false">:$nofold,
+    OpBuilder<(ins "Type":$resultType, "Value":$source, "ValueRange":$low,
+      "ValueRange":$high, CArg<"bool", "false">:$nofold,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
     // Build a PadOp with mixed static and dynamic entries and custom
     // result type. If the type passed is nullptr, it is inferred.

diff  --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 1297e87714f79..c4f9fa8a6fe05 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -123,6 +123,11 @@ Operation *cloneWithoutRegions(OpBuilder &b, Operation *op,
                                TypeRange newResultTypes,
                                ValueRange newOperands);
 
+/// Returns the attributes of `op`, in their original order, skipping every
+/// attribute whose name appears in `elidedAttrs`.
+SmallVector<NamedAttribute>
+getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f2da1088eb04d..9d26e51e04fd5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2518,26 +2518,27 @@ RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
   return RankedTensorType::get(inferredShape, sourceType.getElementType());
 }
 
-void PadOp::build(OpBuilder &b, OperationState &result, Value source,
-                  ArrayRef<int64_t> staticLow, ArrayRef<int64_t> staticHigh,
-                  ValueRange low, ValueRange high, bool nofold,
-                  ArrayRef<NamedAttribute> attrs) {
+void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                  Value source, ArrayRef<int64_t> staticLow,
+                  ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
+                  bool nofold, ArrayRef<NamedAttribute> attrs) {
   auto sourceType = source.getType().cast<RankedTensorType>();
-  auto resultType = inferResultType(sourceType, staticLow, staticHigh);
+  if (!resultType) // A null result type is inferred from the static pads.
+    resultType = inferResultType(sourceType, staticLow, staticHigh);
   build(b, result, resultType, source, low, high,
         b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
         nofold ? b.getUnitAttr() : UnitAttr());
   result.addAttributes(attrs);
 }
 
-void PadOp::build(OpBuilder &b, OperationState &result, Value source,
-                  ValueRange low, ValueRange high, bool nofold,
+void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                  Value source, ValueRange low, ValueRange high, bool nofold,
                   ArrayRef<NamedAttribute> attrs) {
   auto sourceType = source.getType().cast<RankedTensorType>();
   unsigned rank = sourceType.getRank();
   SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
-  build(b, result, source, staticVector, staticVector, low, high, nofold,
-        attrs);
+  build(b, result, resultType, source, staticVector, staticVector, low, high,
+        nofold, attrs); // All pads dynamic; resultType may be null.
 }
 
 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
@@ -2635,9 +2636,9 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
     } else {
       auto newOp = rewriter.create<PadOp>(
           padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
-          padTensorOp.getLow(), padTensorOp.getHigh(),
           padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
-          padTensorOp.getNofold());
+          padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
+          getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
       IRMapping mapper;
       padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
 
@@ -2667,9 +2668,10 @@ struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
 
     auto replacementOp = rewriter.create<PadOp>(
         padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
-        padTensorOp.getSource(), padTensorOp.getLow(), padTensorOp.getHigh(),
-        padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
-        padTensorOp.getNofold());
+        padTensorOp.getSource(), padTensorOp.getStaticLow(),
+        padTensorOp.getStaticHigh(), padTensorOp.getLow(),
+        padTensorOp.getHigh(), padTensorOp.getNofold(),
+        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
     replacementOp.getRegion().takeBody(padTensorOp.getRegion());
 
     rewriter.replaceOp(padTensorOp, replacementOp.getResult());
@@ -2827,7 +2829,8 @@ struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
         innerSliceOp.getMixedStrides());
     auto newPadOp = rewriter.create<PadOp>(
         padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
-        padOp.getMixedLowPad(), newHighPad, padOp.getNofold());
+        padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
+        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
     rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                 newPadOp.getRegion().begin());
     rewriter.replaceOp(padOp, newPadOp.getResult());
@@ -2916,8 +2919,9 @@ struct FoldStaticPadding : public OpRewritePattern<PadOp> {
     auto newResultType = RankedTensorType::get(
         newOutDims, padTensorOp.getType().getElementType());
     auto newOp = rewriter.create<PadOp>(
-        padTensorOp->getLoc(), newResultType, input, padTensorOp.getLow(),
-        padTensorOp.getHigh(), staticLow, staticHigh, padTensorOp.getNofold());
+        padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
+        padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
+        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
 
     IRMapping mapper;
     padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

diff  --git a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
index b22f42c09da59..49b49ef639708 100644
--- a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/IRMapping.h"
+#include "llvm/ADT/StringSet.h"
 
 #include "mlir/Dialect/Utils/DialectUtilsEnums.cpp.inc"
 
@@ -114,3 +115,18 @@ Operation *mlir::cloneWithoutRegions(OpBuilder &b, Operation *op,
     state.addRegion();
   return b.create(state);
 }
+
+// Collect `op`'s attributes in their original order, skipping any attribute
+// whose name appears in `elidedAttrs`.
+SmallVector<NamedAttribute>
+mlir::getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs) {
+  llvm::StringSet<> elidedAttrsSet;
+  elidedAttrsSet.insert(elidedAttrs.begin(), elidedAttrs.end());
+  SmallVector<NamedAttribute> attrs;
+  for (auto attr : op->getAttrs()) {
+    if (elidedAttrsSet.count(attr.getName()))
+      continue;
+    attrs.push_back(attr);
+  }
+  return attrs;
+}


        


More information about the Mlir-commits mailing list