[Mlir-commits] [mlir] c780184 - [MLIR][Transform] Expose map layout option in `OneShotBufferizeOp`

Lorenzo Chelini llvmlistbot at llvm.org
Mon Nov 14 09:09:59 PST 2022


Author: Lorenzo Chelini
Date: 2022-11-14T18:09:54+01:00
New Revision: c780184a84c4c6bab2e15065c3cde470e4c72cd0

URL: https://github.com/llvm/llvm-project/commit/c780184a84c4c6bab2e15065c3cde470e4c72cd0
DIFF: https://github.com/llvm/llvm-project/commit/c780184a84c4c6bab2e15065c3cde470e4c72cd0.diff

LOG: [MLIR][Transform] Expose layout map option in `OneShotBufferizeOp`

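Expose `function-boundary-type-conversion` in `OneShotBufferizeOp` as the
optional `function_boundary_type_conversion` attribute. To reuse the
`LayoutMapOption` enum between passes and transform operations, move it out
of `BufferizationOptions` into a new `BufferizationEnums.td`.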

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D137833

Added: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
    mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h
    mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
    mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
    mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
    mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
    mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
    mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 2c9dd66c45c46..a5324e1345af1 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -14,6 +14,8 @@
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/SetVector.h"
 
+#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
+
 namespace mlir {
 class OpBuilder;
 
@@ -187,12 +189,6 @@ struct BufferizationOptions {
   using UnknownTypeConverterFn = std::function<BaseMemRefType(
       Value, unsigned, const BufferizationOptions &)>;
 
-  enum class LayoutMapOption : int8_t {
-    InferLayoutMap = 0,
-    IdentityLayoutMap = 1,
-    FullyDynamicLayoutMap = 2
-  };
-
   BufferizationOptions();
 
   /// Try to cast the given op to BufferizableOpInterface if the op is allow
@@ -585,6 +581,10 @@ bool defaultIsRepetitiveRegion(BufferizableOpInterface bufferizableOp,
 } // namespace bufferization
 } // namespace mlir
 
+//===----------------------------------------------------------------------===//
+// Bufferization Interfaces
+//===----------------------------------------------------------------------===//
+
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h.inc"
 
 #endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_

diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td
new file mode 100644
index 0000000000000..92423614e85b1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td
@@ -0,0 +1,27 @@
+//===- BufferizationEnums.td - Bufferization enums ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the definition file for enums used in Bufferization.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BUFFERIZATION_ENUMS
+#define BUFFERIZATION_ENUMS
+
+include "mlir/IR/EnumAttr.td"
+
+def LayoutMapOption : I32EnumAttr<"LayoutMapOption", 
+                                  "option for map layout", [
+  I32EnumAttrCase<"InferLayoutMap", 0>,
+  I32EnumAttrCase<"IdentityLayoutMap", 1>,
+  I32EnumAttrCase<"FullyDynamicLayoutMap", 2>
+]> {
+  let cppNamespace = "::mlir::bufferization";
+}
+
+#endif // BUFFERIZATION_ENUMS

diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
index 8ddfe5a384c06..aa93534a78fea 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
@@ -2,3 +2,9 @@ add_mlir_dialect(BufferizationOps bufferization)
 add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc)
 add_mlir_interface(AllocationOpInterface)
 add_mlir_interface(BufferizableOpInterface)
+
+set(LLVM_TARGET_DEFINITIONS BufferizationEnums.td)
+mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
+mlir_tablegen(BufferizationEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRBufferizationEnumsIncGen)
+add_dependencies(mlir-headers MLIRBufferizationEnumsIncGen)

diff  --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h
index bc51845b96064..0aab581ee12d4 100644
--- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h
+++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H
 #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H
 
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/OpImplementation.h"

diff  --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
index 72e679604c3b0..e63ecbf90dfa1 100644
--- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
@@ -9,6 +9,7 @@
 #ifndef BUFFERIZATION_TRANSFORM_OPS
 #define BUFFERIZATION_TRANSFORM_OPS
 
+include "mlir/Dialect/Bufferization/IR/BufferizationEnums.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
@@ -42,6 +43,7 @@ def OneShotBufferizeOp
 
   let arguments = (
       ins PDL_Operation:$target,
+      OptionalAttr<LayoutMapOption>:$function_boundary_type_conversion,
       DefaultValuedAttr<BoolAttr, "false">:$allow_return_allocs,
       DefaultValuedAttr<BoolAttr, "false">:$allow_unknown_ops,
       DefaultValuedAttr<BoolAttr, "false">:$bufferize_function_boundaries,
@@ -52,7 +54,10 @@ def OneShotBufferizeOp
 
   let results = (outs);
 
-  let assemblyFormat = "$target attr-dict";
+  let assemblyFormat = [{
+    (`layout` `{` $function_boundary_type_conversion^ `}`)?
+    $target attr-dict
+  }];
 }
 
 #endif // BUFFERIZATION_TRANSFORM_OPS

diff  --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index 0c085a40adcf1..e77414067e3a8 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
   DEPENDS
   MLIRAllocationOpInterfaceIncGen
   MLIRBufferizationOpsIncGen
+  MLIRBufferizationEnumsIncGen
 
   LINK_LIBS PUBLIC
   MLIRAffineDialect

diff  --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index fc3c386d74a4a..9415bf792816b 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -34,6 +34,9 @@ transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
   options.createDeallocs = getCreateDeallocs();
   options.testAnalysisOnly = getTestAnalysisOnly();
   options.printConflicts = getPrintConflicts();
+  if (getFunctionBoundaryTypeConversion().has_value())
+    options.functionBoundaryTypeConversion =
+        *getFunctionBoundaryTypeConversion();
 
   ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
   for (Operation *target : payloadOps) {
@@ -94,6 +97,8 @@ class BufferizationTransformDialectExtension
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
 
+#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc"
+
 void mlir::bufferization::registerTransformDialectExtension(
     DialectRegistry &registry) {
   registry.addExtensions<BufferizationTransformDialectExtension>();

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index e4355ad996021..7c3b7c8fcc6a9 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -163,14 +163,13 @@ struct FinalizingBufferizePass
   }
 };
 
-static BufferizationOptions::LayoutMapOption
-parseLayoutMapOption(const std::string &s) {
+static LayoutMapOption parseLayoutMapOption(const std::string &s) {
   if (s == "fully-dynamic-layout-map")
-    return BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap;
+    return LayoutMapOption::FullyDynamicLayoutMap;
   if (s == "identity-layout-map")
-    return BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
+    return LayoutMapOption::IdentityLayoutMap;
   if (s == "infer-layout-map")
-    return BufferizationOptions::LayoutMapOption::InferLayoutMap;
+    return LayoutMapOption::InferLayoutMap;
   llvm_unreachable("invalid layout map option");
 }
 
@@ -216,19 +215,17 @@ struct OneShotBufferizePass
       opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
 
       // Configure type converter.
-      BufferizationOptions::LayoutMapOption unknownTypeConversionOption =
+      LayoutMapOption unknownTypeConversionOption =
           parseLayoutMapOption(unknownTypeConversion);
       opt.unknownTypeConverterFn = [=](Value value, unsigned memorySpace,
                                        const BufferizationOptions &options) {
         auto tensorType = value.getType().cast<TensorType>();
-        if (unknownTypeConversionOption ==
-            BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
+        if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
           return bufferization::getMemRefTypeWithStaticIdentityLayout(
               tensorType, memorySpace);
-        assert(
-            unknownTypeConversionOption ==
-                BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap &&
-            "invalid layout map option");
+        assert(unknownTypeConversionOption ==
+                   LayoutMapOption::FullyDynamicLayoutMap &&
+               "invalid layout map option");
         return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
                                                                   memorySpace);
       };

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index 453d71f1d5b58..e23c5c3c51b5a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -18,6 +18,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
 
   DEPENDS
   MLIRBufferizationPassIncGen
+  MLIRBufferizationEnumsIncGen
 
   LINK_LIBS PUBLIC
   MLIRBufferizationDialect

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 49c57f4921878..91060dd6b1394 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -69,7 +69,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
 
   BaseMemRefType memrefType;
   if (options.functionBoundaryTypeConversion ==
-      BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
+      LayoutMapOption::IdentityLayoutMap) {
     memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
   } else {
     // Note: Layout maps on function parameters cannot be inferred. The best we
@@ -471,7 +471,7 @@ struct FuncOpInterface
 
       BaseMemRefType resultType;
       if (options.functionBoundaryTypeConversion ==
-          BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
+          LayoutMapOption::IdentityLayoutMap) {
         resultType = getMemRefTypeWithStaticIdentityLayout(tensorType);
       } else {
         // Note: If `InferLayoutMap`, cast are later folded away.

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index badcf292805b5..fb1d50c466f9c 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -423,7 +423,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
       return failure();
     // Change buffer return types to more precise layout maps.
     if (options.functionBoundaryTypeConversion ==
-        BufferizationOptions::LayoutMapOption::InferLayoutMap)
+        LayoutMapOption::InferLayoutMap)
       foldMemRefCasts(funcOp);
   }
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index 478bac5b24d6d..0e1fbafc5f760 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -32,8 +32,7 @@ getBufferizationOptions(bool analysisOnly) {
   // TODO(springerm): To spot memory leaks more easily, returning dense allocs
   // should be disallowed.
   options.allowReturnAllocs = true;
-  options.functionBoundaryTypeConversion =
-      BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
+  options.functionBoundaryTypeConversion = LayoutMapOption::IdentityLayoutMap;
   options.unknownTypeConverterFn = [](Value value, unsigned memorySpace,
                                       const BufferizationOptions &options) {
     return getMemRefTypeWithStaticIdentityLayout(

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
index 151c8e6996319..4ff8a23a769fd 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
@@ -96,3 +96,25 @@ module {
     return %0 : tensor<?xf32>
   }
 }
+
+// -----
+
+// Test we use identity layout at function boundaries.
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+  transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %arg1 {
+    target_is_module = true,
+    bufferize_function_boundaries = true }
+}
+
+// CHECK: func.func @matmul(
+// CHECK-SAME:  %[[A:.*]]: memref<12x9xf32>,
+// CHECK-SAME:  %[[B:.*]]: memref<9x6xf32>,
+// CHECK-SAME:  %[[C:.*]]: memref<12x6xf32>) -> memref<12x6xf32> {
+func.func @matmul(%A: tensor<12x9xf32>, %B: tensor<9x6xf32>, %C: tensor<12x6xf32>) -> tensor<12x6xf32> {
+  // CHECK: linalg.matmul ins(%[[A]], %[[B]] : memref<12x9xf32>, memref<9x6xf32>) outs(%[[C]] : memref<12x6xf32>)
+  %D = linalg.matmul ins(%A, %B: tensor<12x9xf32>, tensor<9x6xf32>) outs(%C: tensor<12x6xf32>) -> tensor<12x6xf32>
+  // CHECK: return %[[C]] : memref<12x6xf32>
+  return %D : tensor<12x6xf32>
+}


        


More information about the Mlir-commits mailing list