[Mlir-commits] [mlir] f68939d - [MLIR] Tighten type constraint on memref.global op def

Uday Bondhugula llvmlistbot at llvm.org
Wed Sep 15 10:12:00 PDT 2021


Author: Uday Bondhugula
Date: 2021-09-15T22:41:03+05:30
New Revision: f68939d3d91c3e1b57fba5450fa9146c3dcf5fdc

URL: https://github.com/llvm/llvm-project/commit/f68939d3d91c3e1b57fba5450fa9146c3dcf5fdc
DIFF: https://github.com/llvm/llvm-project/commit/f68939d3d91c3e1b57fba5450fa9146c3dcf5fdc.diff

LOG: [MLIR] Tighten type constraint on memref.global op def

Tighten the definition of the memref.global op so that its `type` attribute
uses a TypeAttr constrained to MemRefType instead of an arbitrary type.

Differential Revision: https://reviews.llvm.org/D109822

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c6c3ac8e12669..e0cb3816efafe 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -18,6 +18,12 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
 
+/// A TypeAttr for memref types.
+def MemRefTypeAttr
+    : TypeAttrBase<"::mlir::MemRefType", "memref type attribute"> {
+  let constBuilderCall = "::mlir::TypeAttr::get($0)";
+}
+
 class MemRef_Op<string mnemonic, list<OpTrait> traits = []>
     : Op<MemRef_Dialect, mnemonic, traits> {
   let printer = [{ return ::print(p, *this); }];
@@ -597,14 +603,14 @@ def MemRef_GetGlobalOp : MemRef_Op<"get_global",
 def MemRef_GlobalOp : MemRef_Op<"global", [Symbol]> {
   let summary = "declare or define a global memref variable";
   let description = [{
-    The `memref.global` operation declares or defines a named global variable.
-    The backing memory for the variable is allocated statically and is described
-    by the type of the variable (which should be a statically shaped memref
-    type). The operation is a declaration if no `inital_value` is specified,
-    else it is a definition. The `initial_value` can either be a unit attribute
-    to represent a definition of an uninitialized global variable, or an
-    elements attribute to represent the definition of a global variable with an
-    initial value. The global variable can also be marked constant using the
+    The `memref.global` operation declares or defines a named global memref
+    variable. The backing memory for the variable is allocated statically and is
+    described by the type of the variable (which should be a statically shaped
+    memref type). The operation is a declaration if no `initial_value` is
+    specified, else it is a definition. The `initial_value` can either be a unit
+    attribute to represent a definition of an uninitialized global variable, or
+    an elements attribute to represent the definition of a global variable with
+    an initial value. The global variable can also be marked constant using the
     `constant` unit attribute. Writing to such constant global variables is
     undefined.
 
@@ -633,7 +639,7 @@ def MemRef_GlobalOp : MemRef_Op<"global", [Symbol]> {
   let arguments = (ins
       SymbolNameAttr:$sym_name,
       OptionalAttr<StrAttr>:$sym_visibility,
-      TypeAttr:$type,
+      MemRefTypeAttr:$type,
       OptionalAttr<AnyAttr>:$initial_value,
       UnitAttr:$constant
   );

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index ea3c9943f4b46..ebca204ab8486 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -434,7 +434,7 @@ struct GlobalMemrefOpLowering
   LogicalResult
   matchAndRewrite(memref::GlobalOp global, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    MemRefType type = global.type().cast<MemRefType>();
+    MemRefType type = global.type();
     if (!isConvertibleAndHasIdentityMaps(type))
       return failure();
 

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
index 518405aabb49f..c916a73e16d9c 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
@@ -46,7 +46,7 @@ memref::GlobalOp GlobalCreator::getGlobalFor(ConstantOp constantOp) {
   auto global = globalBuilder.create<memref::GlobalOp>(
       constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
       /*sym_visibility=*/globalBuilder.getStringAttr("private"),
-      /*type=*/typeConverter.convertType(type),
+      /*type=*/typeConverter.convertType(type).cast<MemRefType>(),
       /*initial_value=*/constantOp.getValue().cast<ElementsAttr>(),
       /*constant=*/true);
   symbolTable.insert(global);


        


More information about the Mlir-commits mailing list