[Mlir-commits] [mlir] e07d749 - [mlir][linalg] Propagate attributes when doing named ops conversion.

Hanhan Wang llvmlistbot at llvm.org
Thu Sep 15 11:45:30 PDT 2022


Author: Hanhan Wang
Date: 2022-09-15T11:44:52-07:00
New Revision: e07d749e7df07381ed90d564020eda0c1c89d7e3

URL: https://github.com/llvm/llvm-project/commit/e07d749e7df07381ed90d564020eda0c1c89d7e3
DIFF: https://github.com/llvm/llvm-project/commit/e07d749e7df07381ed90d564020eda0c1c89d7e3.diff

LOG: [mlir][linalg] Propagate attributes when doing named ops conversion.

Custom attributes can be set on the operation. This change prevents them
from being dropped when doing named ops conversion.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D133892

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
    mlir/test/Dialect/Linalg/namedop_conversion.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 6f31e6a3abeae..807e63acc1a4a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/StringSet.h"
 
 namespace mlir {
 class AffineExpr;
@@ -495,6 +496,23 @@ struct GenerateLoopNest {
                    ArrayRef<linalg::ProcInfo> procInfo = {});
 };
 
+/// Returns an attribute list that excludes pre-defined attributes.
+template <typename OpTy>
+SmallVector<NamedAttribute> getPrunedAttributeList(OpTy op) {
+  llvm::StringSet<> elidedAttrs;
+  elidedAttrs.insert(op.getAttributeNames().begin(),
+                     op.getAttributeNames().end());
+  if (isa<linalg::LinalgOp>(op.getOperation()))
+    elidedAttrs.insert(LinalgDialect::kMemoizedIndexingMapsAttrName);
+  SmallVector<NamedAttribute> attrs;
+  for (auto attr : op->getAttrs()) {
+    if (elidedAttrs.count(attr.getName()))
+      continue;
+    attrs.push_back(attr);
+  }
+  return attrs;
+}
+
 } // namespace linalg
 } // namespace mlir
 

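For illustration, a conversion pattern can use this helper to carry user-set
attributes over to whatever op it creates, which is exactly what the
NamedOpConversions change below does. A minimal sketch (the wrapper
transferPrunedAttributes is hypothetical, not an API added by this patch):

#include "mlir/Dialect/Linalg/Utils/Utils.h"

using namespace mlir;

// Copy every attribute that getPrunedAttributeList keeps (i.e. everything
// that is not one of the op's own declared attributes, nor the memoized
// indexing-maps attribute for Linalg ops) onto a freshly created op.
template <typename OpTy>
static void transferPrunedAttributes(OpTy oldOp, Operation *newOp) {
  for (NamedAttribute attr : linalg::getPrunedAttributeList(oldOp))
    newOp->setAttr(attr.getName(), attr.getValue());
}
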
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
index 273518f8c3824..c2a8dbe000c2d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_LINALGNAMEDOPCONVERSION
@@ -72,28 +73,30 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
   auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
       loc, newInitTy, init, collapsedInitDims);
 
-  Value newConv;
-  if (isa<DepthwiseConv2DNhwcHwcmOp>(operation)) {
-    newConv = rewriter
-                  .create<DepthwiseConv2DNhwcHwcOp>(
-                      loc, newInitTy, ValueRange{input, collapsedKernel},
-                      ValueRange{collapsedInit}, stride, dilation)
-                  .getResult(0);
-  } else if (isa<DepthwiseConv2DNhwcHwcmQOp>(operation)) {
-    newConv =
-        rewriter
-            .create<DepthwiseConv2DNhwcHwcQOp>(
+  SmallVector<NamedAttribute> preservedAttrs;
+  Operation *newConv =
+      TypeSwitch<Operation *, Operation *>(operation)
+          .Case<DepthwiseConv2DNhwcHwcmOp>([&](auto op) {
+            preservedAttrs = getPrunedAttributeList(op);
+            return rewriter.create<DepthwiseConv2DNhwcHwcOp>(
+                loc, newInitTy, ValueRange{input, collapsedKernel},
+                ValueRange{collapsedInit}, stride, dilation);
+          })
+          .Case<DepthwiseConv2DNhwcHwcmQOp>([&](auto op) {
+            preservedAttrs = getPrunedAttributeList(op);
+            return rewriter.create<DepthwiseConv2DNhwcHwcQOp>(
                 loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
-                ValueRange{collapsedInit}, stride, dilation)
-            .getResult(0);
-  }
-
+                ValueRange{collapsedInit}, stride, dilation);
+          })
+          .Default([](Operation *op) { return nullptr; });
   if (!newConv)
     return failure();
+  for (auto attr : preservedAttrs)
+    newConv->setAttr(attr.getName(), attr.getValue());
 
   // Expand dimensions back out to the original result shape.
   rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-      operation, resultTy, newConv, collapsedInitDims);
+      operation, resultTy, newConv->getResult(0), collapsedInitDims);
   return success();
 }
 

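The rewrite above replaces an isa<> if/else chain with llvm::TypeSwitch,
which dispatches on the concrete op type while producing a common
Operation * result. A minimal, self-contained illustration of the idiom,
using a hypothetical class hierarchy rather than MLIR ops:

#include "llvm/ADT/TypeSwitch.h"
#include <cstdio>

// Stand-in hierarchy with LLVM-style RTTI so isa<>/dyn_cast<> work.
struct Shape {
  enum Kind { KCircle, KSquare } kind;
  Shape(Kind k) : kind(k) {}
};
struct Circle : Shape {
  Circle() : Shape(KCircle) {}
  static bool classof(const Shape *s) { return s->kind == KCircle; }
};
struct Square : Shape {
  Square() : Shape(KSquare) {}
  static bool classof(const Shape *s) { return s->kind == KSquare; }
};

int main() {
  Circle c;
  Shape *s = &c;
  // As in matchAndReplaceDepthwiseConv: the first matching Case wins and
  // Default is the fallback for anything else (nullptr above, a label here).
  const char *name = llvm::TypeSwitch<Shape *, const char *>(s)
                         .Case<Circle>([](Circle *) { return "circle"; })
                         .Case<Square>([](Square *) { return "square"; })
                         .Default([](Shape *) { return "unknown"; });
  std::printf("%s\n", name); // prints "circle"
}
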
diff --git a/mlir/test/Dialect/Linalg/namedop_conversion.mlir b/mlir/test/Dialect/Linalg/namedop_conversion.mlir
index 8b779f2e496ba..4f2f2720037cd 100644
--- a/mlir/test/Dialect/Linalg/namedop_conversion.mlir
+++ b/mlir/test/Dialect/Linalg/namedop_conversion.mlir
@@ -4,9 +4,9 @@
 func.func @depthwise_conv(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>, %arg2: tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32> {
   // CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
   // CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
-  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?x?xf32>)
+  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc {_someattr, dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?x?xf32>)
   // CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
-  %0 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x1xf32>) outs(%arg2 : tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32>
+  %0 = linalg.depthwise_conv_2d_nhwc_hwcm {_someattr, dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x1xf32>) outs(%arg2 : tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32>
   return %0 : tensor<?x?x?x?x1xf32>
 }
 
@@ -17,8 +17,8 @@ func.func @depthwise_conv(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>
 func.func @depthwise_conv_q(%arg0: tensor<?x?x?x?xi8>, %arg1: tensor<?x?x?x1xi8>, %arg2: tensor<?x?x?x?x1xi32>, %arg3 : i32, %arg4 : i32) -> tensor<?x?x?x?x1xi32> {
   // CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
   // CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
-  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%[[INIT]] : tensor<?x?x?x?xi32>)
+  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc_q {_someattr, dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%[[INIT]] : tensor<?x?x?x?xi32>)
   // CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
-  %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x1xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x1xi32>) -> tensor<?x?x?x?x1xi32>
+  %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {_someattr, dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x1xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x1xi32>) -> tensor<?x?x?x?x1xi32>
   return %0 : tensor<?x?x?x?x1xi32>
 }


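For reference, the test above is presumably driven by a RUN line along these
lines; it sits above the hunks shown, so it is not part of this diff, and the
pass flag is an assumption inferred from GEN_PASS_DEF_LINALGNAMEDOPCONVERSION
in the C++ change rather than quoted from the file:

// RUN: mlir-opt -split-input-file -linalg-named-op-conversion %s | FileCheck %s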