[Mlir-commits] [mlir] 657f68b - [NFC][mlir][MemRef] Make use of InferTypeOpInterface

Quentin Colombet llvmlistbot at llvm.org
Fri Oct 14 11:52:00 PDT 2022


Author: Quentin Colombet
Date: 2022-10-14T18:49:37Z
New Revision: 657f68b1f2fd38deb63c23d8f46d12b7fd357e63

URL: https://github.com/llvm/llvm-project/commit/657f68b1f2fd38deb63c23d8f46d12b7fd357e63
DIFF: https://github.com/llvm/llvm-project/commit/657f68b1f2fd38deb63c23d8f46d12b7fd357e63.diff

LOG: [NFC][mlir][MemRef] Make use of InferTypeOpInterface

The `InferTypeOpInterface` generates builders that do not take explicit
result types, for ops whose result types it can infer.
Thanks to that interface we can:
- Eliminate a builder for DimOp, and
- Describe how to infer the result types of `extract_strided_metadata`
  from its source, and get a simpler builder as a result

NFC

Differential Revision: https://reviews.llvm.org/D135734

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index bd99cf2750fe9..6538abceb6293 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -15,6 +15,7 @@
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"

diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 54394dadafcad..ba8fe8103269c 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -14,6 +14,7 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/CopyOpInterface.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/ShapedOpInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
@@ -581,7 +582,6 @@ def MemRef_DimOp : MemRef_Op<"dim", [
 
   let builders = [
     OpBuilder<(ins "Value":$source, "int64_t":$index)>,
-    OpBuilder<(ins "Value":$source, "Value":$index)>
   ];
 
   let extraClassDeclaration = [{
@@ -853,7 +853,8 @@ def MemRef_ExtractAlignedPointerAsIndexOp :
 def MemRef_ExtractStridedMetadataOp : MemRef_Op<"extract_strided_metadata", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     Pure, 
-    SameVariadicResultSize]> {
+    SameVariadicResultSize,
+    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Extracts a buffer base with offset and strides";
   let description = [{
     Extracts a base buffer, offset and strides. This op allows additional layers

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index fbc1eadffcba3..ab7311b3d101f 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -807,12 +807,6 @@ void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
   build(builder, result, source, indexValue);
 }
 
-void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
-                  Value index) {
-  auto indexTy = builder.getIndexType();
-  build(builder, result, indexTy, source, index);
-}
-
 Optional<int64_t> DimOp::getConstantIndex() {
   if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
     return constantOp.getValue().cast<IntegerAttr>().getInt();
@@ -1254,6 +1248,32 @@ void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
 // ExtractStridedMetadataOp
 //===----------------------------------------------------------------------===//
 
+/// The number and type of the results are inferred from the
+/// shape of the source.
+LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  ExtractStridedMetadataOpAdaptor extractAdaptor(operands, attributes, regions);
+  auto sourceType = extractAdaptor.getSource().getType().dyn_cast<MemRefType>();
+  if (!sourceType)
+    return failure();
+
+  unsigned sourceRank = sourceType.getRank();
+  IndexType indexType = IndexType::get(context);
+  auto memrefType =
+      MemRefType::get({}, sourceType.getElementType(),
+                      MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
+  // Base.
+  inferredReturnTypes.push_back(memrefType);
+  // Offset.
+  inferredReturnTypes.push_back(indexType);
+  // Sizes and strides.
+  for (unsigned i = 0; i < sourceRank * 2; ++i)
+    inferredReturnTypes.push_back(indexType);
+  return success();
+}
+
 void ExtractStridedMetadataOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
   setNameFn(getBaseBuffer(), "base_buffer");

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
index 257a02b9155d7..2a8ffba9c32c4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -59,16 +59,12 @@ struct ExtractStridedMetadataOpSubviewFolder
     // Build a plain extract_strided_metadata(memref) from
     // extract_strided_metadata(subview(memref)).
     Location origLoc = op.getLoc();
-    IndexType indexType = rewriter.getIndexType();
     Value source = subview.getSource();
     auto sourceType = source.getType().cast<MemRefType>();
     unsigned sourceRank = sourceType.getRank();
-    SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
 
     auto newExtractStridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(
-            origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
-            sizeStrideTypes, source);
+        rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
 
     SmallVector<int64_t> sourceStrides;
     int64_t sourceOffset;
@@ -486,16 +482,12 @@ struct ExtractStridedMetadataOpReshapeFolder
     // Build a plain extract_strided_metadata(memref) from
     // extract_strided_metadata(reassociative_reshape_like(memref)).
     Location origLoc = op.getLoc();
-    IndexType indexType = rewriter.getIndexType();
     Value source = reshape.getSrc();
     auto sourceType = source.getType().cast<MemRefType>();
     unsigned sourceRank = sourceType.getRank();
-    SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
 
     auto newExtractStridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(
-            origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
-            sizeStrideTypes, source);
+        rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
 
     // Collect statically known information.
     SmallVector<int64_t> strides;


        


More information about the Mlir-commits mailing list