[Mlir-commits] [mlir] b4d08df - [mlir] Remove incorrect builders for ExpandShapeOp

Thomas Raoux llvmlistbot at llvm.org
Fri Mar 18 15:31:44 PDT 2022


Author: Thomas Raoux
Date: 2022-03-18T22:31:17Z
New Revision: b4d08dfd9d40493cb41296a95475beb3596c437e

URL: https://github.com/llvm/llvm-project/commit/b4d08dfd9d40493cb41296a95475beb3596c437e
DIFF: https://github.com/llvm/llvm-project/commit/b4d08dfd9d40493cb41296a95475beb3596c437e.diff

LOG: [mlir] Remove incorrect builders for ExpandShapeOp

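The ExpandShapeOp builders that infer the result type from `src` and
`reassociation` are incorrect: from those inputs alone they cannot know
how each source dimension needs to be split. Remove these builders so
that they don't get used accidentally; only the builders that take an
explicit result type remain. Also remove the one potential path using
them in generic fusion (the FoldConsumerReshapeOpByLinearization
patterns instantiated for tensor::ExpandShapeOp).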

Differential Revision: https://reviews.llvm.org/D122019

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 5c4fbe291b5d2..c7519f1125f12 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1193,41 +1193,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
       [NoSideEffect, ViewLikeOpInterface])>,
     Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
     Results<(outs AnyStridedMemRef:$result)>{
-  let builders = [
-    // Builders for a contracting reshape whose result type is computed from
-    // `src` and `reassociation`.
-    OpBuilder<(ins "Value":$src,
-      "ArrayRef<ReassociationIndices>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
-    OpBuilder<(ins "Value":$src,
-      "ArrayRef<ReassociationExprs>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
-    [{
-      auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
-      build($_builder, $_state, src, reassociationMaps, attrs);
-    }]>,
-
-    // Builders for a reshape whose result type is passed explicitly. This may
-    // be either a contracting or expanding reshape.
-    OpBuilder<(ins "Type":$resultType, "Value":$src,
-      "ArrayRef<ReassociationIndices>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
-    [{
-      build($_builder, $_state, resultType, src, attrs);
-      $_state.addAttribute("reassociation",
-                          getReassociationIndicesAttribute($_builder, reassociation));
-    }]>,
-    OpBuilder<(ins "Type":$resultType, "Value":$src,
-      "ArrayRef<ReassociationExprs>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
-    [{
-      auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
-      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
-    }]>
-  ];
-
+  
   code commonExtraClassDeclaration = [{
     SmallVector<AffineMap, 4> getReassociationMaps();
     SmallVector<ReassociationExprs, 4> getReassociationExprs();
@@ -1288,6 +1254,25 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape"> {
       memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
     ```
   }];
+  let builders = [
+    // Builders using ReassociationIndices.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationIndices>":$reassociation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+    [{
+      build($_builder, $_state, resultType, src, attrs);
+      $_state.addAttribute("reassociation",
+                          getReassociationIndicesAttribute($_builder, reassociation));
+    }]>,
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationExprs>":$reassociation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+    [{
+      auto reassociationMaps =
+          convertReassociationMapsToIndices($_builder, reassociation);
+      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+    }]>
+  ];
   let extraClassDeclaration = commonExtraClassDeclaration;
   let hasVerifier = 1;
 }
@@ -1326,6 +1311,39 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape"> {
       memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>
     ```
   }];
+  let builders = [
+    // Builders for a contracting reshape whose result type is computed from
+    // `src` and `reassociation`.
+    OpBuilder<(ins "Value":$src,
+      "ArrayRef<ReassociationIndices>":$reassociation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    OpBuilder<(ins "Value":$src,
+      "ArrayRef<ReassociationExprs>":$reassociation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+    [{
+      auto reassociationMaps =
+          convertReassociationMapsToIndices($_builder, reassociation);
+      build($_builder, $_state, src, reassociationMaps, attrs);
+    }]>,
+
+    // Builders for a reshape whose result type is passed explicitly.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationIndices>":$reassociation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+    [{
+      build($_builder, $_state, resultType, src, attrs);
+      $_state.addAttribute("reassociation",
+                          getReassociationIndicesAttribute($_builder, reassociation));
+    }]>,
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationExprs>":$reassociation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+    [{
+      auto reassociationMaps =
+          convertReassociationMapsToIndices($_builder, reassociation);
+      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+    }]>
+  ];
   let extraClassDeclaration = commonExtraClassDeclaration;
   let hasVerifier = 1;
 }

diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 2021231f28f24..30c74837efa79 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -678,41 +678,7 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     Tensor_Op<mnemonic, !listconcat(traits, [NoSideEffect])>,
     Arguments<(ins AnyTensor:$src, IndexListArrayAttr:$reassociation)>,
     Results<(outs AnyTensor:$result)> {
-  let builders = [
-    // Builders for a contracting reshape whose result type is computed from
-    // `src` and `reassociation`.
-    OpBuilder<(ins "Value":$src,
-      "ArrayRef<ReassociationIndices>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
-    OpBuilder<(ins "Value":$src,
-      "ArrayRef<ReassociationExprs>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
-    [{
-      auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
-      build($_builder, $_state, src, reassociationMaps, attrs);
-    }]>,
-
-    // Builders for a reshape whose result type is passed explicitly. This may
-    // be either a contracting or expanding reshape.
-    OpBuilder<(ins "Type":$resultType, "Value":$src,
-      "ArrayRef<ReassociationIndices>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
-    [{
-      build($_builder, $_state, resultType, src, attrs);
-      $_state.addAttribute("reassociation",
-          getReassociationIndicesAttribute($_builder, reassociation));
-    }]>,
-    OpBuilder<(ins "Type":$resultType, "Value":$src,
-      "ArrayRef<ReassociationExprs>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
-    [{
-      auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
-      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
-    }]>
-  ];
-
+ 
   code commonExtraClassDeclaration = [{
     static StringRef getReassociationAttrName() { return "reassociation"; }
     SmallVector<AffineMap, 4> getReassociationMaps();
@@ -768,6 +734,26 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
         : tensor<?x?xf32> into tensor<?x?x?xf32>
     ```
   }];
+  let builders = [
+    // Builders using ReassociationIndices.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationIndices>":$reassociation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+    [{
+      build($_builder, $_state, resultType, src, attrs);
+      $_state.addAttribute("reassociation",
+          getReassociationIndicesAttribute($_builder, reassociation));
+    }]>,
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationExprs>":$reassociation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+    [{
+      auto reassociationMaps =
+          convertReassociationMapsToIndices($_builder, reassociation);
+      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+    }]>
+  ];
+
   let extraClassDeclaration = commonExtraClassDeclaration;
   let hasVerifier = 1;
 }
@@ -797,6 +783,40 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
         : tensor<?x?x?xf32> into tensor<?x?xf32>
     ```
   }];
+  let builders = [
+    // Builders for a contracting reshape whose result type is computed from
+    // `src` and `reassociation`.
+    OpBuilder<(ins "Value":$src,
+      "ArrayRef<ReassociationIndices>":$reassociation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    OpBuilder<(ins "Value":$src,
+      "ArrayRef<ReassociationExprs>":$reassociation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+    [{
+      auto reassociationMaps =
+          convertReassociationMapsToIndices($_builder, reassociation);
+      build($_builder, $_state, src, reassociationMaps, attrs);
+    }]>,
+
+    // Builders for a reshape whose result type is passed explicitly.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationIndices>":$reassociation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+    [{
+      build($_builder, $_state, resultType, src, attrs);
+      $_state.addAttribute("reassociation",
+          getReassociationIndicesAttribute($_builder, reassociation));
+    }]>,
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationExprs>":$reassociation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+    [{
+      auto reassociationMaps =
+          convertReassociationMapsToIndices($_builder, reassociation);
+      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+    }]>
+  ];
+
   let extraClassDeclaration = commonExtraClassDeclaration;
   let hasVerifier = 1;
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 26baea0f16c24..4effa26cb75cf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -2223,12 +2223,11 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
 
 void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
     RewritePatternSet &patterns) {
-  patterns
-      .add<FoldProducerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
-           FoldProducerReshapeOpByLinearization<false, tensor::ExpandShapeOp>,
-           FoldConsumerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
-           FoldConsumerReshapeOpByLinearization<false, tensor::ExpandShapeOp>>(
-          patterns.getContext());
+  patterns.add<
+      FoldProducerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
+      FoldProducerReshapeOpByLinearization<false, tensor::ExpandShapeOp>,
+      FoldConsumerReshapeOpByLinearization<false, tensor::CollapseShapeOp>>(
+      patterns.getContext());
 }
 
 void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
@@ -2236,8 +2235,7 @@ void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
   patterns
       .add<FoldProducerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
            FoldProducerReshapeOpByLinearization<true, tensor::ExpandShapeOp>,
-           FoldConsumerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
-           FoldConsumerReshapeOpByLinearization<true, tensor::ExpandShapeOp>>(
+           FoldConsumerReshapeOpByLinearization<true, tensor::CollapseShapeOp>>(
           patterns.getContext());
 }
 

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 7a5d87172ddd2..93dd6f052b4dc 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1669,18 +1669,6 @@ computeReshapeCollapsedType(MemRefType type,
           AffineMapAttr::get(layout)));
 }
 
-void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src,
-                          ArrayRef<ReassociationIndices> reassociation,
-                          ArrayRef<NamedAttribute> attrs) {
-  auto memRefType = src.getType().cast<MemRefType>();
-  auto resultType = computeReshapeCollapsedType(
-      memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
-                      b.getContext(), reassociation)));
-  build(b, result, resultType, src, attrs);
-  result.addAttribute(getReassociationAttrName(),
-                      getReassociationIndicesAttribute(b, reassociation));
-}
-
 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                             ArrayRef<ReassociationIndices> reassociation,
                             ArrayRef<NamedAttribute> attrs) {

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 2789fd3b2cd9a..b95a214034076 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -817,18 +817,6 @@ void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                       getReassociationIndicesAttribute(b, reassociation));
 }
 
-void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src,
-                          ArrayRef<ReassociationIndices> reassociation,
-                          ArrayRef<NamedAttribute> attrs) {
-  auto resultType = computeTensorReshapeCollapsedType(
-      src.getType().cast<RankedTensorType>(),
-      getSymbolLessAffineMaps(
-          convertReassociationIndicesToExprs(b.getContext(), reassociation)));
-  build(b, result, resultType, src, attrs);
-  result.addAttribute(getReassociationAttrName(),
-                      getReassociationIndicesAttribute(b, reassociation));
-}
-
 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                         TensorReshapeOp, ExpandShapeOp>::value>
 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,


        


More information about the Mlir-commits mailing list