[Mlir-commits] [mlir] 657f68b - [NFC][mlir][MemRef] Make use of InferTypeOpInterface
Quentin Colombet
llvmlistbot at llvm.org
Fri Oct 14 11:52:00 PDT 2022
Author: Quentin Colombet
Date: 2022-10-14T18:49:37Z
New Revision: 657f68b1f2fd38deb63c23d8f46d12b7fd357e63
URL: https://github.com/llvm/llvm-project/commit/657f68b1f2fd38deb63c23d8f46d12b7fd357e63
DIFF: https://github.com/llvm/llvm-project/commit/657f68b1f2fd38deb63c23d8f46d12b7fd357e63.diff
LOG: [NFC][mlir][MemRef] Make use of InferTypeOpInterface
The `InferTypeOpInterface` generates builders for ops whose result types
it can infer.
Thanks to that interface we can:
- Eliminate a builder for DimOp, and
- Describe how to infer the result types of `extract_strided_metadata`
from its source, and get a simpler builder as a result
NFC
Differential Revision: https://reviews.llvm.org/D135734
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index bd99cf2750fe9..6538abceb6293 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -15,6 +15,7 @@
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/CopyOpInterface.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 54394dadafcad..ba8fe8103269c 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -14,6 +14,7 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/CopyOpInterface.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/ShapedOpInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
@@ -581,7 +582,6 @@ def MemRef_DimOp : MemRef_Op<"dim", [
let builders = [
OpBuilder<(ins "Value":$source, "int64_t":$index)>,
- OpBuilder<(ins "Value":$source, "Value":$index)>
];
let extraClassDeclaration = [{
@@ -853,7 +853,8 @@ def MemRef_ExtractAlignedPointerAsIndexOp :
def MemRef_ExtractStridedMetadataOp : MemRef_Op<"extract_strided_metadata", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
Pure,
- SameVariadicResultSize]> {
+ SameVariadicResultSize,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Extracts a buffer base with offset and strides";
let description = [{
Extracts a base buffer, offset and strides. This op allows additional layers
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index fbc1eadffcba3..ab7311b3d101f 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -807,12 +807,6 @@ void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
build(builder, result, source, indexValue);
}
-void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
- Value index) {
- auto indexTy = builder.getIndexType();
- build(builder, result, indexTy, source, index);
-}
-
Optional<int64_t> DimOp::getConstantIndex() {
if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
return constantOp.getValue().cast<IntegerAttr>().getInt();
@@ -1254,6 +1248,32 @@ void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
// ExtractStridedMetadataOp
//===----------------------------------------------------------------------===//
+/// The number and type of the results are inferred from the
+/// shape of the source.
+LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ DictionaryAttr attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ ExtractStridedMetadataOpAdaptor extractAdaptor(operands, attributes, regions);
+ auto sourceType = extractAdaptor.getSource().getType().dyn_cast<MemRefType>();
+ if (!sourceType)
+ return failure();
+
+ unsigned sourceRank = sourceType.getRank();
+ IndexType indexType = IndexType::get(context);
+ auto memrefType =
+ MemRefType::get({}, sourceType.getElementType(),
+ MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
+ // Base.
+ inferredReturnTypes.push_back(memrefType);
+ // Offset.
+ inferredReturnTypes.push_back(indexType);
+ // Sizes and strides.
+ for (unsigned i = 0; i < sourceRank * 2; ++i)
+ inferredReturnTypes.push_back(indexType);
+ return success();
+}
+
void ExtractStridedMetadataOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getBaseBuffer(), "base_buffer");
diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
index 257a02b9155d7..2a8ffba9c32c4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -59,16 +59,12 @@ struct ExtractStridedMetadataOpSubviewFolder
// Build a plain extract_strided_metadata(memref) from
// extract_strided_metadata(subview(memref)).
Location origLoc = op.getLoc();
- IndexType indexType = rewriter.getIndexType();
Value source = subview.getSource();
auto sourceType = source.getType().cast<MemRefType>();
unsigned sourceRank = sourceType.getRank();
- SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
auto newExtractStridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(
- origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
- sizeStrideTypes, source);
+ rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
SmallVector<int64_t> sourceStrides;
int64_t sourceOffset;
@@ -486,16 +482,12 @@ struct ExtractStridedMetadataOpReshapeFolder
// Build a plain extract_strided_metadata(memref) from
// extract_strided_metadata(reassociative_reshape_like(memref)).
Location origLoc = op.getLoc();
- IndexType indexType = rewriter.getIndexType();
Value source = reshape.getSrc();
auto sourceType = source.getType().cast<MemRefType>();
unsigned sourceRank = sourceType.getRank();
- SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
auto newExtractStridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(
- origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
- sizeStrideTypes, source);
+ rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
// Collect statically known information.
SmallVector<int64_t> strides;
More information about the Mlir-commits
mailing list