[Mlir-commits] [mlir] d09d0bd - [mlir][NFC] Sort the operations alphabetically and add header blocks
River Riddle
llvmlistbot at llvm.org
Wed Mar 4 09:53:20 PST 2020
Author: River Riddle
Date: 2020-03-04T09:48:55-08:00
New Revision: d09d0bd7a019d4300d4c0a4000fd0f208b838e0a
URL: https://github.com/llvm/llvm-project/commit/d09d0bd7a019d4300d4c0a4000fd0f208b838e0a
DIFF: https://github.com/llvm/llvm-project/commit/d09d0bd7a019d4300d4c0a4000fd0f208b838e0a.diff
LOG: [mlir][NFC] Sort the operations alphabetically and add header blocks
Summary:
The operations in Ops.td and Ops.cpp are no longer in alphabetical order, as operations have been renamed and new ones have been added over time. This reorders them and adds the standard section header blocks.
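For reference, each operation definition in Ops.td (and the corresponding section in Ops.cpp) is now preceded by a header block of this form, shown here for AddFOp as it appears in the diff below:

//===----------------------------------------------------------------------===//
// AddFOp
//===----------------------------------------------------------------------===//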
Differential Revision: https://reviews.llvm.org/D75540
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 08a81b385bf3..a9217b30bd28 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -121,6 +121,10 @@ class FloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
ArithmeticOp<mnemonic, traits>,
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>;
+//===----------------------------------------------------------------------===//
+// AbsFOp
+//===----------------------------------------------------------------------===//
+
def AbsFOp : FloatUnaryOp<"absf"> {
let summary = "floating point absolute-value operation";
let description = [{
@@ -131,16 +135,28 @@ def AbsFOp : FloatUnaryOp<"absf"> {
}];
}
+//===----------------------------------------------------------------------===//
+// AddFOp
+//===----------------------------------------------------------------------===//
+
def AddFOp : FloatArithmeticOp<"addf"> {
let summary = "floating point addition operation";
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// AddIOp
+//===----------------------------------------------------------------------===//
+
def AddIOp : IntArithmeticOp<"addi", [Commutative]> {
let summary = "integer addition operation";
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// AllocOp
+//===----------------------------------------------------------------------===//
+
def AllocOp : Std_Op<"alloc"> {
let summary = "memory allocation operation";
let description = [{
@@ -213,11 +229,40 @@ def AllocOp : Std_Op<"alloc"> {
let hasCanonicalizer = 1;
}
+//===----------------------------------------------------------------------===//
+// AndOp
+//===----------------------------------------------------------------------===//
+
def AndOp : IntArithmeticOp<"and", [Commutative]> {
let summary = "integer binary and";
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// AssumeAlignmentOp
+//===----------------------------------------------------------------------===//
+
+def AssumeAlignmentOp : Std_Op<"assume_alignment"> {
+ let summary =
+ "assertion that gives alignment information to the input memref";
+ let description = [{
+ The assume alignment operation takes a memref and a integer of alignment
+ value, and internally annotates the buffer with the given alignment. If
+ the buffer isn't aligned to the given alignment, the behavior is undefined.
+
+ This operation doesn't affect the semantics of a correct program. It's for
+ optimization only, and the optimization is best-effort.
+ }];
+ let arguments = (ins AnyMemRef:$memref, PositiveI32Attr:$alignment);
+ let results = (outs);
+
+ let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
+}
+
+//===----------------------------------------------------------------------===//
+// AtomicRMWOp
+//===----------------------------------------------------------------------===//
+
def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>;
def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>;
def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>;
@@ -281,6 +326,10 @@ def AtomicRMWOp : Std_Op<"atomic_rmw", [
}];
}
+//===----------------------------------------------------------------------===//
+// BranchOp
+//===----------------------------------------------------------------------===//
+
def BranchOp : Std_Op<"br", [Terminator]> {
let summary = "branch operation";
let description = [{
@@ -316,6 +365,10 @@ def BranchOp : Std_Op<"br", [Terminator]> {
let assemblyFormat = "$dest attr-dict";
}
+//===----------------------------------------------------------------------===//
+// CallOp
+//===----------------------------------------------------------------------===//
+
def CallOp : Std_Op<"call", [CallOpInterface]> {
let summary = "call operation";
let description = [{
@@ -372,6 +425,10 @@ def CallOp : Std_Op<"call", [CallOpInterface]> {
}];
}
+//===----------------------------------------------------------------------===//
+// CallIndirectOp
+//===----------------------------------------------------------------------===//
+
def CallIndirectOp : Std_Op<"call_indirect", [
CallOpInterface,
TypesMatchWith<"callee input types match argument types",
@@ -423,6 +480,10 @@ def CallIndirectOp : Std_Op<"call_indirect", [
let assemblyFormat = "$callee `(` $operands `)` attr-dict `:` type($callee)";
}
+//===----------------------------------------------------------------------===//
+// CeilFOp
+//===----------------------------------------------------------------------===//
+
def CeilFOp : FloatUnaryOp<"ceilf"> {
let summary = "ceiling of the specified value";
let description = [{
@@ -433,6 +494,10 @@ def CeilFOp : FloatUnaryOp<"ceilf"> {
}];
}
+//===----------------------------------------------------------------------===//
+// CmpFOp
+//===----------------------------------------------------------------------===//
+
// The predicate indicates the type of the comparison to perform:
// (un)orderedness, (in)equality and less/greater than (or equal to) as
// well as predicates that are always true or false.
@@ -519,6 +584,10 @@ def CmpFOp : Std_Op<"cmpf",
let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
}
+//===----------------------------------------------------------------------===//
+// CmpIOp
+//===----------------------------------------------------------------------===//
+
def CMPI_P_EQ : I64EnumAttrCase<"eq", 0>;
def CMPI_P_NE : I64EnumAttrCase<"ne", 1>;
def CMPI_P_SLT : I64EnumAttrCase<"slt", 2>;
@@ -594,6 +663,10 @@ def CmpIOp : Std_Op<"cmpi",
let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
}
+//===----------------------------------------------------------------------===//
+// CondBranchOp
+//===----------------------------------------------------------------------===//
+
def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
let summary = "conditional branch operation";
let description = [{
@@ -705,6 +778,10 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
let assemblyFormat = "$condition `,` successors attr-dict";
}
+//===----------------------------------------------------------------------===//
+// ConstantOp
+//===----------------------------------------------------------------------===//
+
def ConstantOp : Std_Op<"constant",
[NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
let summary = "constant";
@@ -727,6 +804,10 @@ def ConstantOp : Std_Op<"constant",
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// CopySignOp
+//===----------------------------------------------------------------------===//
+
def CopySignOp : FloatArithmeticOp<"copysign"> {
let summary = "A copysign operation";
let description = [{
@@ -738,6 +819,10 @@ def CopySignOp : FloatArithmeticOp<"copysign"> {
}];
}
+//===----------------------------------------------------------------------===//
+// CosOp
+//===----------------------------------------------------------------------===//
+
def CosOp : FloatUnaryOp<"cos"> {
let summary = "cosine of the specified value";
let description = [{
@@ -748,6 +833,10 @@ def CosOp : FloatUnaryOp<"cos"> {
}];
}
+//===----------------------------------------------------------------------===//
+// DeallocOp
+//===----------------------------------------------------------------------===//
+
def DeallocOp : Std_Op<"dealloc"> {
let summary = "memory deallocation operation";
let description = [{
@@ -768,6 +857,10 @@ def DeallocOp : Std_Op<"dealloc"> {
let assemblyFormat = "$memref attr-dict `:` type($memref)";
}
+//===----------------------------------------------------------------------===//
+// DimOp
+//===----------------------------------------------------------------------===//
+
def DimOp : Std_Op<"dim", [NoSideEffect]> {
let summary = "dimension index operation";
let description = [{
@@ -800,24 +893,26 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// DivFOp
+//===----------------------------------------------------------------------===//
+
def DivFOp : FloatArithmeticOp<"divf"> {
let summary = "floating point division operation";
}
-def SignedDivIOp : IntArithmeticOp<"divi_signed"> {
- let summary = "signed integer division operation";
- let hasFolder = 1;
-}
-
-def UnsignedDivIOp : IntArithmeticOp<"divi_unsigned"> {
- let summary = "unsigned integer division operation";
- let hasFolder = 1;
-}
+//===----------------------------------------------------------------------===//
+// ExpOp
+//===----------------------------------------------------------------------===//
def ExpOp : FloatUnaryOp<"exp"> {
let summary = "base-e exponential of the specified value";
}
+//===----------------------------------------------------------------------===//
+// ExtractElementOp
+//===----------------------------------------------------------------------===//
+
def ExtractElementOp : Std_Op<"extract_element",
[NoSideEffect,
TypesMatchWith<"result type matches element type of aggregate",
@@ -862,22 +957,9 @@ def ExtractElementOp : Std_Op<"extract_element",
}];
}
-def IndexCastOp : CastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
- let summary = "cast between index and integer types";
- let description = [{
- Casts between integer scalars and 'index' scalars. Index is an integer of
- platform-specific bit width. If casting to a wider integer, the value is
- sign-extended. If casting to a narrower integer, the value is truncated.
- }];
-
- let extraClassDeclaration = [{
- /// Return true if `a` and `b` are valid operand and result pairs for
- /// the operation.
- static bool areCastCompatible(Type a, Type b);
- }];
-
- let hasFolder = 1;
-}
+//===----------------------------------------------------------------------===//
+// FPExtOp
+//===----------------------------------------------------------------------===//
def FPExtOp : CastOp<"fpext">, Arguments<(ins AnyType:$in)> {
let summary = "cast from floating-point to wider floating-point";
@@ -896,6 +978,10 @@ def FPExtOp : CastOp<"fpext">, Arguments<(ins AnyType:$in)> {
let hasFolder = 0;
}
+//===----------------------------------------------------------------------===//
+// FPTruncOp
+//===----------------------------------------------------------------------===//
+
def FPTruncOp : CastOp<"fptrunc">, Arguments<(ins AnyType:$in)> {
let summary = "cast from floating-point to narrower floating-point";
let description = [{
@@ -914,6 +1000,31 @@ def FPTruncOp : CastOp<"fptrunc">, Arguments<(ins AnyType:$in)> {
let hasFolder = 0;
}
+//===----------------------------------------------------------------------===//
+// IndexCastOp
+//===----------------------------------------------------------------------===//
+
+def IndexCastOp : CastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
+ let summary = "cast between index and integer types";
+ let description = [{
+ Casts between integer scalars and 'index' scalars. Index is an integer of
+ platform-specific bit width. If casting to a wider integer, the value is
+ sign-extended. If casting to a narrower integer, the value is truncated.
+ }];
+
+ let extraClassDeclaration = [{
+ /// Return true if `a` and `b` are valid operand and result pairs for
+ /// the operation.
+ static bool areCastCompatible(Type a, Type b);
+ }];
+
+ let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// LoadOp
+//===----------------------------------------------------------------------===//
+
def LoadOp : Std_Op<"load",
[TypesMatchWith<"result type matches element type of 'memref'",
"memref", "result",
@@ -956,6 +1067,10 @@ def LoadOp : Std_Op<"load",
let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";
}
+//===----------------------------------------------------------------------===//
+// LogOp
+//===----------------------------------------------------------------------===//
+
def LogOp : FloatUnaryOp<"log"> {
let summary = "base-e logarithm of the specified value";
}
@@ -968,6 +1083,10 @@ def Log2Op : FloatUnaryOp<"log2"> {
let summary = "base-2 logarithm of the specified value";
}
+//===----------------------------------------------------------------------===//
+// MemRefCastOp
+//===----------------------------------------------------------------------===//
+
def MemRefCastOp : CastOp<"memref_cast"> {
let summary = "memref cast operation";
let description = [{
@@ -1022,16 +1141,28 @@ def MemRefCastOp : CastOp<"memref_cast"> {
}];
}
+//===----------------------------------------------------------------------===//
+// MulFOp
+//===----------------------------------------------------------------------===//
+
def MulFOp : FloatArithmeticOp<"mulf"> {
let summary = "floating point multiplication operation";
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// MulIOp
+//===----------------------------------------------------------------------===//
+
def MulIOp : IntArithmeticOp<"muli", [Commutative]> {
let summary = "integer multiplication operation";
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// NegFOp
+//===----------------------------------------------------------------------===//
+
def NegFOp : FloatUnaryOp<"negf"> {
let summary = "floating point negation";
let description = [{
@@ -1042,11 +1173,19 @@ def NegFOp : FloatUnaryOp<"negf"> {
}];
}
+//===----------------------------------------------------------------------===//
+// OrOp
+//===----------------------------------------------------------------------===//
+
def OrOp : IntArithmeticOp<"or", [Commutative]> {
let summary = "integer binary or";
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// PrefetchOp
+//===----------------------------------------------------------------------===//
+
def PrefetchOp : Std_Op<"prefetch"> {
let summary = "prefetch operation";
let description = [{
@@ -1096,6 +1235,10 @@ def PrefetchOp : Std_Op<"prefetch"> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// RankOp
+//===----------------------------------------------------------------------===//
+
def RankOp : Std_Op<"rank", [NoSideEffect]> {
let summary = "rank operation";
let description = [{
@@ -1118,29 +1261,17 @@ def RankOp : Std_Op<"rank", [NoSideEffect]> {
let assemblyFormat = "operands attr-dict `:` type(operands)";
}
+//===----------------------------------------------------------------------===//
+// RemFOp
+//===----------------------------------------------------------------------===//
+
def RemFOp : FloatArithmeticOp<"remf"> {
let summary = "floating point division remainder operation";
}
-def RsqrtOp : FloatUnaryOp<"rsqrt"> {
- let summary = "reciprocal of sqrt (1 / sqrt of the specified value)";
- let description = [{
- The `rsqrt` operation computes the reciprocal of the square root. It takes
- one operand and returns one result of the same type. This type may be a
- float scalar type, a vector whose element type is float, or a tensor of
- floats. It has no standard attributes.
- }];
-}
-
-def SignedRemIOp : IntArithmeticOp<"remi_signed"> {
- let summary = "signed integer division remainder operation";
- let hasFolder = 1;
-}
-
-def UnsignedRemIOp : IntArithmeticOp<"remi_unsigned"> {
- let summary = "unsigned integer division remainder operation";
- let hasFolder = 1;
-}
+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
def ReturnOp : Std_Op<"return", [Terminator, HasParent<"FuncOp">]> {
let summary = "return operation";
@@ -1164,6 +1295,24 @@ def ReturnOp : Std_Op<"return", [Terminator, HasParent<"FuncOp">]> {
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}
+//===----------------------------------------------------------------------===//
+// RsqrtOp
+//===----------------------------------------------------------------------===//
+
+def RsqrtOp : FloatUnaryOp<"rsqrt"> {
+ let summary = "reciprocal of sqrt (1 / sqrt of the specified value)";
+ let description = [{
+ The `rsqrt` operation computes the reciprocal of the square root. It takes
+ one operand and returns one result of the same type. This type may be a
+ float scalar type, a vector whose element type is float, or a tensor of
+ floats. It has no standard attributes.
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
AllTypesMatch<["true_value", "false_value", "result"]>,
TypesMatchWith<"condition type matches i1 equivalent of result type",
@@ -1209,6 +1358,64 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
}];
}
+//===----------------------------------------------------------------------===//
+// ShiftLeftOp
+//===----------------------------------------------------------------------===//
+
+def ShiftLeftOp : IntArithmeticOp<"shift_left"> {
+ let summary = "integer left-shift";
+ let description = [{
+ The shift_left operation shifts an integer value to the left by a variable
+ amount. The low order bits are filled with zeros.
+
+ %1 = constant 5 : i8 // %1 is 0b00000101
+ %2 = constant 3 : i8
+ %3 = shift_left %1, %2 : (i8, i8) -> i8 // %3 is 0b00101000
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// SignedDivIOp
+//===----------------------------------------------------------------------===//
+
+def SignedDivIOp : IntArithmeticOp<"divi_signed"> {
+ let summary = "signed integer division operation";
+ let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// SignedRemIOp
+//===----------------------------------------------------------------------===//
+
+def SignedRemIOp : IntArithmeticOp<"remi_signed"> {
+ let summary = "signed integer division remainder operation";
+ let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// SignedShiftRightOp
+//===----------------------------------------------------------------------===//
+
+def SignedShiftRightOp : IntArithmeticOp<"shift_right_signed"> {
+ let summary = "signed integer right-shift";
+ let description = [{
+ The shift_right_signed operation shifts an integer value to the right by
+ a variable amount. The integer is interpreted as signed. The high order
+ bits in the output are filled with copies of the most-significant bit
+ of the shifted value (which means that the sign of the value is preserved).
+
+ %1 = constant 160 : i8 // %1 is 0b10100000
+ %2 = constant 3 : i8
+ %3 = shift_right_signed %1, %2 : (i8, i8) -> i8 // %3 is 0b11110100
+ %4 = constant 96 : i8 // %4 is 0b01100000
+ %5 = shift_right_signed %4, %2 : (i8, i8) -> i8 // %5 is 0b00001100
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// SignExtendIOp
+//===----------------------------------------------------------------------===//
+
def SignExtendIOp : Std_Op<"sexti",
[NoSideEffect, SameOperandsAndResultShape]> {
let summary = "integer sign extension operation";
@@ -1244,46 +1451,9 @@ def SignExtendIOp : Std_Op<"sexti",
}];
}
-def ShiftLeftOp : IntArithmeticOp<"shift_left"> {
- let summary = "integer left-shift";
- let description = [{
- The shift_left operation shifts an integer value to the left by a variable
- amount. The low order bits are filled with zeros.
-
- %1 = constant 5 : i8 // %1 is 0b00000101
- %2 = constant 3 : i8
- %3 = shift_left %1, %2 : (i8, i8) -> i8 // %3 is 0b00101000
- }];
-}
-
-def SignedShiftRightOp : IntArithmeticOp<"shift_right_signed"> {
- let summary = "signed integer right-shift";
- let description = [{
- The shift_right_signed operation shifts an integer value to the right by
- a variable amount. The integer is interpreted as signed. The high order
- bits in the output are filled with copies of the most-significant bit
- of the shifted value (which means that the sign of the value is preserved).
-
- %1 = constant 160 : i8 // %1 is 0b10100000
- %2 = constant 3 : i8
- %3 = shift_right_signed %1, %2 : (i8, i8) -> i8 // %3 is 0b11110100
- %4 = constant 96 : i8 // %4 is 0b01100000
- %5 = shift_right_signed %4, %2 : (i8, i8) -> i8 // %5 is 0b00001100
- }];
-}
-
-def UnsignedShiftRightOp : IntArithmeticOp<"shift_right_unsigned"> {
- let summary = "unsigned integer right-shift";
- let description = [{
- The shift_right_unsigned operation shifts an integer value to the right by
- a variable amount. The integer is interpreted as unsigned. The high order
- bits are always filled with zeros.
-
- %1 = constant 160 : i8 // %1 is 0b10100000
- %2 = constant 3 : i8
- %3 = shift_right_unsigned %1, %2 : (i8, i8) -> i8 // %3 is 0b00010100
- }];
-}
+//===----------------------------------------------------------------------===//
+// SIToFPOp
+//===----------------------------------------------------------------------===//
def SIToFPOp : CastOp<"sitofp">, Arguments<(ins AnyType:$in)> {
let summary = "cast from integer type to floating-point";
@@ -1303,6 +1473,10 @@ def SIToFPOp : CastOp<"sitofp">, Arguments<(ins AnyType:$in)> {
let hasFolder = 0;
}
+//===----------------------------------------------------------------------===//
+// SplatOp
+//===----------------------------------------------------------------------===//
+
def SplatOp : Std_Op<"splat", [NoSideEffect,
TypesMatchWith<"operand type matches element type of result",
"aggregate", "input",
@@ -1343,6 +1517,24 @@ def SplatOp : Std_Op<"splat", [NoSideEffect,
let assemblyFormat = "$input attr-dict `:` type($aggregate)";
}
+//===----------------------------------------------------------------------===//
+// SqrtOp
+//===----------------------------------------------------------------------===//
+
+def SqrtOp : FloatUnaryOp<"sqrt"> {
+ let summary = "sqrt of the specified value";
+ let description = [{
+ The `sqrt` operation computes the square root. It takes one operand and
+ returns one result of the same type. This type may be a float scalar type, a
+ vector whose element type is float, or a tensor of floats. It has no standard
+ attributes.
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// StoreOp
+//===----------------------------------------------------------------------===//
+
def StoreOp : Std_Op<"store",
[TypesMatchWith<"type of 'value' matches element type of 'memref'",
"memref", "value",
@@ -1389,16 +1581,28 @@ def StoreOp : Std_Op<"store",
}];
}
+//===----------------------------------------------------------------------===//
+// SubFOp
+//===----------------------------------------------------------------------===//
+
def SubFOp : FloatArithmeticOp<"subf"> {
let summary = "floating point subtraction operation";
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// SubIOp
+//===----------------------------------------------------------------------===//
+
def SubIOp : IntArithmeticOp<"subi"> {
let summary = "integer subtraction operation";
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// SubViewOp
+//===----------------------------------------------------------------------===//
+
def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, NoSideEffect]> {
let summary = "memref subview operation";
let description = [{
@@ -1565,15 +1769,9 @@ def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, NoSideEffect]> {
let hasCanonicalizer = 1;
}
-def SqrtOp : FloatUnaryOp<"sqrt"> {
- let summary = "sqrt of the specified value";
- let description = [{
- The `sqrt` operation computes the square root. It takes one operand and
- returns one result of the same type. This type may be a float scalar type, a
- vector whose element type is float, or a tensor of floats. It has no standard
- attributes.
- }];
-}
+//===----------------------------------------------------------------------===//
+// TanhOp
+//===----------------------------------------------------------------------===//
def TanhOp : FloatUnaryOp<"tanh"> {
let summary = "hyperbolic tangent of the specified value";
@@ -1585,6 +1783,10 @@ def TanhOp : FloatUnaryOp<"tanh"> {
}];
}
+//===----------------------------------------------------------------------===//
+// TensorCastOp
+//===----------------------------------------------------------------------===//
+
def TensorCastOp : CastOp<"tensor_cast"> {
let summary = "tensor cast operation";
let description = [{
@@ -1611,6 +1813,10 @@ def TensorCastOp : CastOp<"tensor_cast"> {
}];
}
+//===----------------------------------------------------------------------===//
+// TensorLoadOp
+//===----------------------------------------------------------------------===//
+
def TensorLoadOp : Std_Op<"tensor_load",
[SameOperandsAndResultShape, SameOperandsAndResultElementType,
TypesMatchWith<"result type matches tensor equivalent of 'memref'",
@@ -1648,6 +1854,10 @@ def TensorLoadOp : Std_Op<"tensor_load",
let assemblyFormat = "$memref attr-dict `:` type($memref)";
}
+//===----------------------------------------------------------------------===//
+// TensorStoreOp
+//===----------------------------------------------------------------------===//
+
def TensorStoreOp : Std_Op<"tensor_store",
[SameOperandsShape, SameOperandsElementType,
TypesMatchWith<"type of 'value' matches tensor equivalent of 'memref'",
@@ -1673,6 +1883,10 @@ def TensorStoreOp : Std_Op<"tensor_store",
let assemblyFormat = "$tensor `,` $memref attr-dict `:` type($memref)";
}
+//===----------------------------------------------------------------------===//
+// TruncateIOp
+//===----------------------------------------------------------------------===//
+
def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape]> {
let summary = "integer truncation operation";
let description = [{
@@ -1705,6 +1919,45 @@ def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape]> {
}];
}
+//===----------------------------------------------------------------------===//
+// UnsignedDivIOp
+//===----------------------------------------------------------------------===//
+
+def UnsignedDivIOp : IntArithmeticOp<"divi_unsigned"> {
+ let summary = "unsigned integer division operation";
+ let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// UnsignedRemIOp
+//===----------------------------------------------------------------------===//
+
+def UnsignedRemIOp : IntArithmeticOp<"remi_unsigned"> {
+ let summary = "unsigned integer division remainder operation";
+ let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// UnsignedShiftRightOp
+//===----------------------------------------------------------------------===//
+
+def UnsignedShiftRightOp : IntArithmeticOp<"shift_right_unsigned"> {
+ let summary = "unsigned integer right-shift";
+ let description = [{
+ The shift_right_unsigned operation shifts an integer value to the right by
+ a variable amount. The integer is interpreted as unsigned. The high order
+ bits are always filled with zeros.
+
+ %1 = constant 160 : i8 // %1 is 0b10100000
+ %2 = constant 3 : i8
+ %3 = shift_right_unsigned %1, %2 : (i8, i8) -> i8 // %3 is 0b00010100
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// ViewOp
+//===----------------------------------------------------------------------===//
+
def ViewOp : Std_Op<"view", [NoSideEffect]> {
let summary = "memref view operation";
let description = [{
@@ -1767,11 +2020,19 @@ def ViewOp : Std_Op<"view", [NoSideEffect]> {
let hasCanonicalizer = 1;
}
+//===----------------------------------------------------------------------===//
+// XOrOp
+//===----------------------------------------------------------------------===//
+
def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
let summary = "integer binary xor";
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// ZeroExtendIOp
+//===----------------------------------------------------------------------===//
+
def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, SameOperandsAndResultShape]> {
let summary = "integer zero extension operation";
let description = [{
@@ -1805,21 +2066,4 @@ def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, SameOperandsAndResultShape]>
}];
}
-def AssumeAlignmentOp : Std_Op<"assume_alignment"> {
- let summary =
- "assertion that gives alignment information to the input memref";
- let description = [{
- The assume alignment operation takes a memref and a integer of alignment
- value, and internally annotates the buffer with the given alignment. If
- the buffer isn't aligned to the given alignment, the behavior is undefined.
-
- This operation doesn't affect the semantics of a correct program. It's for
- optimization only, and the optimization is best-effort.
- }];
- let arguments = (ins AnyMemRef:$memref, PositiveI32Attr:$alignment);
- let results = (outs);
-
- let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
-}
-
#endif // STANDARD_OPS
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 4f16c76fb7d9..8e7233e8acda 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -389,6 +389,68 @@ void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<SimplifyAllocConst, SimplifyDeadAlloc>(context);
}
+//===----------------------------------------------------------------------===//
+// AndOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
+ /// and(x, 0) -> 0
+ if (matchPattern(rhs(), m_Zero()))
+ return rhs();
+ /// and(x,x) -> x
+ if (lhs() == rhs())
+ return rhs();
+
+ return constFoldBinaryOp<IntegerAttr>(operands,
+ [](APInt a, APInt b) { return a & b; });
+}
+
+//===----------------------------------------------------------------------===//
+// AssumeAlignmentOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(AssumeAlignmentOp op) {
+ unsigned alignment = op.alignment().getZExtValue();
+ if (!llvm::isPowerOf2_32(alignment))
+ return op.emitOpError("alignment must be power of 2");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// AtomicRMWOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(AtomicRMWOp op) {
+ if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
+ return op.emitOpError(
+ "expects the number of subscripts to be equal to memref rank");
+ switch (op.kind()) {
+ case AtomicRMWKind::addf:
+ case AtomicRMWKind::maxf:
+ case AtomicRMWKind::minf:
+ case AtomicRMWKind::mulf:
+ if (!op.value().getType().isa<FloatType>())
+ return op.emitOpError()
+ << "with kind '" << stringifyAtomicRMWKind(op.kind())
+ << "' expects a floating-point type";
+ break;
+ case AtomicRMWKind::addi:
+ case AtomicRMWKind::maxs:
+ case AtomicRMWKind::maxu:
+ case AtomicRMWKind::mins:
+ case AtomicRMWKind::minu:
+ case AtomicRMWKind::muli:
+ if (!op.value().getType().isa<IntegerType>())
+ return op.emitOpError()
+ << "with kind '" << stringifyAtomicRMWKind(op.kind())
+ << "' expects an integer type";
+ break;
+ default:
+ break;
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//
@@ -1009,44 +1071,6 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
return {};
}
-//===----------------------------------------------------------------------===//
-// SignedDivIOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult SignedDivIOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "binary operation takes two operands");
-
- // Don't fold if it would overflow or if it requires a division by zero.
- bool overflowOrDiv0 = false;
- auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
- if (overflowOrDiv0 || !b) {
- overflowOrDiv0 = true;
- return a;
- }
- return a.sdiv_ov(b, overflowOrDiv0);
- });
- return overflowOrDiv0 ? Attribute() : result;
-}
-
-//===----------------------------------------------------------------------===//
-// UnsignedDivIOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult UnsignedDivIOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "binary operation takes two operands");
-
- // Don't fold if it would require a division by zero.
- bool div0 = false;
- auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
- if (div0 || !b) {
- div0 = true;
- return a;
- }
- return a.udiv(b);
- });
- return div0 ? Attribute() : result;
-}
-
// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------
@@ -1289,6 +1313,36 @@ OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
return {};
}
+//===----------------------------------------------------------------------===//
+// FPExtOp
+//===----------------------------------------------------------------------===//
+
+bool FPExtOp::areCastCompatible(Type a, Type b) {
+ if (auto fa = a.dyn_cast<FloatType>())
+ if (auto fb = b.dyn_cast<FloatType>())
+ return fa.getWidth() < fb.getWidth();
+ if (auto va = a.dyn_cast<VectorType>())
+ if (auto vb = b.dyn_cast<VectorType>())
+ return va.getShape().equals(vb.getShape()) &&
+ areCastCompatible(va.getElementType(), vb.getElementType());
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// FPTruncOp
+//===----------------------------------------------------------------------===//
+
+bool FPTruncOp::areCastCompatible(Type a, Type b) {
+ if (auto fa = a.dyn_cast<FloatType>())
+ if (auto fb = b.dyn_cast<FloatType>())
+ return fa.getWidth() > fb.getWidth();
+ if (auto va = a.dyn_cast<VectorType>())
+ if (auto vb = b.dyn_cast<VectorType>())
+ return va.getShape().equals(vb.getShape()) &&
+ areCastCompatible(va.getElementType(), vb.getElementType());
+ return false;
+}
+
//===----------------------------------------------------------------------===//
// IndexCastOp
//===----------------------------------------------------------------------===//
@@ -1435,6 +1489,22 @@ OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) {
[](APInt a, APInt b) { return a * b; });
}
+//===----------------------------------------------------------------------===//
+// OrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
+ /// or(x, 0) -> x
+ if (matchPattern(rhs(), m_Zero()))
+ return lhs();
+ /// or(x,x) -> x
+ if (lhs() == rhs())
+ return rhs();
+
+ return constFoldBinaryOp<IntegerAttr>(operands,
+ [](APInt a, APInt b) { return a | b; });
+}
+
//===----------------------------------------------------------------------===//
// PrefetchOp
//===----------------------------------------------------------------------===//
@@ -1517,58 +1587,6 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
return IntegerAttr();
}
-//===----------------------------------------------------------------------===//
-// SignedRemIOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "remi_signed takes two operands");
-
- auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
- if (!rhs)
- return {};
- auto rhsValue = rhs.getValue();
-
- // x % 1 = 0
- if (rhsValue.isOneValue())
- return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
-
- // Don't fold if it requires division by zero.
- if (rhsValue.isNullValue())
- return {};
-
- auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
- if (!lhs)
- return {};
- return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
-}
-
-//===----------------------------------------------------------------------===//
-// UnsignedRemIOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult UnsignedRemIOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "remi_unsigned takes two operands");
-
- auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
- if (!rhs)
- return {};
- auto rhsValue = rhs.getValue();
-
- // x % 1 = 0
- if (rhsValue.isOneValue())
- return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
-
- // Don't fold if it requires division by zero.
- if (rhsValue.isNullValue())
- return {};
-
- auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
- if (!lhs)
- return {};
- return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
-}
-
//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
@@ -1593,15 +1611,6 @@ static LogicalResult verify(ReturnOp op) {
return success();
}
-//===----------------------------------------------------------------------===//
-// SIToFPOp
-//===----------------------------------------------------------------------===//
-
-// sitofp is applicable from integer types to float types.
-bool SIToFPOp::areCastCompatible(Type a, Type b) {
- return a.isSignlessInteger() && b.isa<FloatType>();
-}
-
//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
@@ -1644,37 +1653,91 @@ static LogicalResult verify(SignExtendIOp op) {
}
//===----------------------------------------------------------------------===//
-// SplatOp
+// SignedDivIOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(SplatOp op) {
- // TODO: we could replace this by a trait.
- if (op.getOperand().getType() !=
- op.getType().cast<ShapedType>().getElementType())
- return op.emitError("operand should be of elemental type of result type");
+OpFoldResult SignedDivIOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 2 && "binary operation takes two operands");
- return success();
+ // Don't fold if it would overflow or if it requires a division by zero.
+ bool overflowOrDiv0 = false;
+ auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
+ if (overflowOrDiv0 || !b) {
+ overflowOrDiv0 = true;
+ return a;
+ }
+ return a.sdiv_ov(b, overflowOrDiv0);
+ });
+ return overflowOrDiv0 ? Attribute() : result;
}
-// Constant folding hook for SplatOp.
-OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 1 && "splat takes one operand");
+//===----------------------------------------------------------------------===//
+// SignedRemIOp
+//===----------------------------------------------------------------------===//
- auto constOperand = operands.front();
- if (!constOperand ||
- (!constOperand.isa<IntegerAttr>() && !constOperand.isa<FloatAttr>()))
+OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 2 && "remi_signed takes two operands");
+
+ auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
+ if (!rhs)
return {};
+ auto rhsValue = rhs.getValue();
- auto shapedType = getType().cast<ShapedType>();
- assert(shapedType.getElementType() == constOperand.getType() &&
- "incorrect input attribute type for folding");
+ // x % 1 = 0
+ if (rhsValue.isOneValue())
+ return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
- // SplatElementsAttr::get treats single value for second arg as being a splat.
- return SplatElementsAttr::get(shapedType, {constOperand});
-}
+ // Don't fold if it requires division by zero.
+ if (rhsValue.isNullValue())
+ return {};
-//===----------------------------------------------------------------------===//
-// StoreOp
+ auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
+ if (!lhs)
+ return {};
+ return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
+}
+
+//===----------------------------------------------------------------------===//
+// SIToFPOp
+//===----------------------------------------------------------------------===//
+
+// sitofp is applicable from integer types to float types.
+bool SIToFPOp::areCastCompatible(Type a, Type b) {
+ return a.isSignlessInteger() && b.isa<FloatType>();
+}
+
+//===----------------------------------------------------------------------===//
+// SplatOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(SplatOp op) {
+ // TODO: we could replace this by a trait.
+ if (op.getOperand().getType() !=
+ op.getType().cast<ShapedType>().getElementType())
+ return op.emitError("operand should be of elemental type of result type");
+
+ return success();
+}
+
+// Constant folding hook for SplatOp.
+OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 1 && "splat takes one operand");
+
+ auto constOperand = operands.front();
+ if (!constOperand ||
+ (!constOperand.isa<IntegerAttr>() && !constOperand.isa<FloatAttr>()))
+ return {};
+
+ auto shapedType = getType().cast<ShapedType>();
+ assert(shapedType.getElementType() == constOperand.getType() &&
+ "incorrect input attribute type for folding");
+
+ // SplatElementsAttr::get treats single value for second arg as being a splat.
+ return SplatElementsAttr::get(shapedType, {constOperand});
+}
+
+//===----------------------------------------------------------------------===//
+// StoreOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(StoreOp op) {
@@ -1713,749 +1776,751 @@ OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
}
//===----------------------------------------------------------------------===//
-// AndOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
- /// and(x, 0) -> 0
- if (matchPattern(rhs(), m_Zero()))
- return rhs();
- /// and(x,x) -> x
- if (lhs() == rhs())
- return rhs();
-
- return constFoldBinaryOp<IntegerAttr>(operands,
- [](APInt a, APInt b) { return a & b; });
-}
-
-//===----------------------------------------------------------------------===//
-// OrOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
- /// or(x, 0) -> x
- if (matchPattern(rhs(), m_Zero()))
- return lhs();
- /// or(x,x) -> x
- if (lhs() == rhs())
- return rhs();
-
- return constFoldBinaryOp<IntegerAttr>(operands,
- [](APInt a, APInt b) { return a | b; });
-}
-
-//===----------------------------------------------------------------------===//
-// XOrOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
- /// xor(x, 0) -> x
- if (matchPattern(rhs(), m_Zero()))
- return lhs();
- /// xor(x,x) -> 0
- if (lhs() == rhs())
- return Builder(getContext()).getZeroAttr(getType());
-
- return constFoldBinaryOp<IntegerAttr>(operands,
- [](APInt a, APInt b) { return a ^ b; });
-}
-
-//===----------------------------------------------------------------------===//
-// TensorCastOp
+// SubViewOp
//===----------------------------------------------------------------------===//
-bool TensorCastOp::areCastCompatible(Type a, Type b) {
- auto aT = a.dyn_cast<TensorType>();
- auto bT = b.dyn_cast<TensorType>();
- if (!aT || !bT)
- return false;
-
- if (aT.getElementType() != bT.getElementType())
- return false;
-
- return succeeded(verifyCompatibleShape(aT, bT));
-}
+// Returns a MemRefType with dynamic sizes and offset and the same stride as the
+// `memRefType` passed as argument.
+// TODO(andydavis,ntv) Evolve to a more powerful inference that can also keep
+// sizes and offset static.
+static Type inferSubViewResultType(MemRefType memRefType) {
+ auto rank = memRefType.getRank();
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ auto res = getStridesAndOffset(memRefType, strides, offset);
+ assert(succeeded(res) && "SubViewOp expected strided memref type");
+ (void)res;
-OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
- return impl::foldCastOp(*this);
+ // Assume sizes and offset are fully dynamic for now until canonicalization
+ // occurs on the ranges. Typed strides don't change though.
+ offset = MemRefType::getDynamicStrideOrOffset();
+ // Overwrite strides because verifier will not pass.
+ // TODO(b/144419106): don't force degrade the strides to fully dynamic.
+ for (auto &stride : strides)
+ stride = MemRefType::getDynamicStrideOrOffset();
+ auto stridedLayout =
+ makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
+ SmallVector<int64_t, 4> sizes(rank, ShapedType::kDynamicSize);
+ return MemRefType::Builder(memRefType)
+ .setShape(sizes)
+ .setAffineMaps(stridedLayout);
}
-//===----------------------------------------------------------------------===//
-// Helpers for Tensor[Load|Store]Op
-//===----------------------------------------------------------------------===//
-
-static Type getTensorTypeFromMemRefType(Type type) {
- if (auto memref = type.dyn_cast<MemRefType>())
- return RankedTensorType::get(memref.getShape(), memref.getElementType());
- return NoneType::get(type.getContext());
+void mlir::SubViewOp::build(Builder *b, OperationState &result, Value source,
+ ValueRange offsets, ValueRange sizes,
+ ValueRange strides, Type resultType,
+ ArrayRef<NamedAttribute> attrs) {
+ if (!resultType)
+ resultType = inferSubViewResultType(source.getType().cast<MemRefType>());
+ auto segmentAttr = b->getI32VectorAttr(
+ {1, static_cast<int>(offsets.size()), static_cast<int32_t>(sizes.size()),
+ static_cast<int32_t>(strides.size())});
+ build(b, result, resultType, source, offsets, sizes, strides, segmentAttr);
+ result.addAttributes(attrs);
}
-//===----------------------------------------------------------------------===//
-// TruncateIOp
-//===----------------------------------------------------------------------===//
-
-static LogicalResult verify(TruncateIOp op) {
- auto srcType = getElementTypeOrSelf(op.getOperand().getType());
- auto dstType = getElementTypeOrSelf(op.getType());
-
- if (srcType.isa<IndexType>())
- return op.emitError() << srcType << " is not a valid operand type";
- if (dstType.isa<IndexType>())
- return op.emitError() << dstType << " is not a valid result type";
-
- if (srcType.cast<IntegerType>().getWidth() <=
- dstType.cast<IntegerType>().getWidth())
- return op.emitError("operand type ")
- << srcType << " must be wider than result type " << dstType;
-
- return success();
+void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType,
+ Value source) {
+ build(b, result, source, /*offsets=*/{}, /*sizes=*/{}, /*strides=*/{},
+ resultType);
}
-//===----------------------------------------------------------------------===//
-// ViewOp
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
+static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType srcInfo;
- SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
+ SmallVector<OpAsmParser::OperandType, 4> offsetsInfo;
SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
+ SmallVector<OpAsmParser::OperandType, 4> stridesInfo;
auto indexType = parser.getBuilder().getIndexType();
Type srcType, dstType;
- llvm::SMLoc offsetLoc;
- if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
- parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
+ if (parser.parseOperand(srcInfo) ||
+ parser.parseOperandList(offsetsInfo, OpAsmParser::Delimiter::Square) ||
+ parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
+ parser.parseOperandList(stridesInfo, OpAsmParser::Delimiter::Square)) {
return failure();
+ }
- if (offsetInfo.size() > 1)
- return parser.emitError(offsetLoc) << "expects 0 or 1 offset operand";
+ auto builder = parser.getBuilder();
+ result.addAttribute(
+ SubViewOp::getOperandSegmentSizeAttr(),
+ builder.getI32VectorAttr({1, static_cast<int>(offsetsInfo.size()),
+ static_cast<int32_t>(sizesInfo.size()),
+ static_cast<int32_t>(stridesInfo.size())}));
return failure(
- parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(srcType) ||
parser.resolveOperand(srcInfo, srcType, result.operands) ||
- parser.resolveOperands(offsetInfo, indexType, result.operands) ||
+ parser.resolveOperands(offsetsInfo, indexType, result.operands) ||
parser.resolveOperands(sizesInfo, indexType, result.operands) ||
+ parser.resolveOperands(stridesInfo, indexType, result.operands) ||
parser.parseKeywordType("to", dstType) ||
parser.addTypeToList(dstType, result.types));
}
-static void print(OpAsmPrinter &p, ViewOp op) {
- p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
- auto dynamicOffset = op.getDynamicOffset();
- if (dynamicOffset != nullptr)
- p.printOperand(dynamicOffset);
- p << "][" << op.getDynamicSizes() << ']';
- p.printOptionalAttrDict(op.getAttrs());
+static void print(OpAsmPrinter &p, SubViewOp op) {
+ p << op.getOperationName() << ' ' << op.getOperand(0) << '[' << op.offsets()
+ << "][" << op.sizes() << "][" << op.strides() << ']';
+
+ std::array<StringRef, 1> elidedAttrs = {
+ SubViewOp::getOperandSegmentSizeAttr()};
+ p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}
-Value ViewOp::getDynamicOffset() {
- int64_t offset;
- SmallVector<int64_t, 4> strides;
- auto result =
- succeeded(mlir::getStridesAndOffset(getType(), strides, offset));
- assert(result);
- if (result && offset == MemRefType::getDynamicStrideOrOffset())
- return getOperand(1);
- return nullptr;
-}
+static LogicalResult verify(SubViewOp op) {
+ auto baseType = op.getBaseMemRefType().cast<MemRefType>();
+ auto subViewType = op.getType();
-static LogicalResult verifyDynamicStrides(MemRefType memrefType,
- ArrayRef<int64_t> strides) {
- ArrayRef<int64_t> shape = memrefType.getShape();
- unsigned rank = memrefType.getRank();
- assert(rank == strides.size());
- bool dynamicStrides = false;
- for (int i = rank - 2; i >= 0; --i) {
- // If size at dim 'i + 1' is dynamic, set the 'dynamicStrides' flag.
- if (ShapedType::isDynamic(shape[i + 1]))
- dynamicStrides = true;
- // If stride at dim 'i' is not dynamic, return error.
- if (dynamicStrides && strides[i] != MemRefType::getDynamicStrideOrOffset())
- return failure();
+ // The rank of the base and result subview must match.
+ if (baseType.getRank() != subViewType.getRank()) {
+ return op.emitError(
+ "expected rank of result type to match rank of base type ");
}
- return success();
-}
-
-static LogicalResult verify(ViewOp op) {
- auto baseType = op.getOperand(0).getType().cast<MemRefType>();
- auto viewType = op.getResult().getType().cast<MemRefType>();
-
- // The base memref should have identity layout map (or none).
- if (baseType.getAffineMaps().size() > 1 ||
- (baseType.getAffineMaps().size() == 1 &&
- !baseType.getAffineMaps()[0].isIdentity()))
- return op.emitError("unsupported map for base memref type ") << baseType;
// The base memref and the view memref should be in the same memory space.
- if (baseType.getMemorySpace() != viewType.getMemorySpace())
+ if (baseType.getMemorySpace() != subViewType.getMemorySpace())
return op.emitError("
diff erent memory spaces specified for base memref "
"type ")
- << baseType << " and view memref type " << viewType;
+ << baseType << " and subview memref type " << subViewType;
+
+ // Verify that the base memref type has a strided layout map.
+ int64_t baseOffset;
+ SmallVector<int64_t, 4> baseStrides;
+ if (failed(getStridesAndOffset(baseType, baseStrides, baseOffset)))
+ return op.emitError("base type ") << subViewType << " is not strided";
// Verify that the result memref type has a strided layout map.
- int64_t offset;
- SmallVector<int64_t, 4> strides;
- if (failed(getStridesAndOffset(viewType, strides, offset)))
- return op.emitError("result type ") << viewType << " is not strided";
+ int64_t subViewOffset;
+ SmallVector<int64_t, 4> subViewStrides;
+ if (failed(getStridesAndOffset(subViewType, subViewStrides, subViewOffset)))
+ return op.emitError("result type ") << subViewType << " is not strided";
- // Verify that we have the correct number of operands for the result type.
- unsigned memrefOperandCount = 1;
- unsigned numDynamicDims = viewType.getNumDynamicDims();
- unsigned dynamicOffsetCount =
- offset == MemRefType::getDynamicStrideOrOffset() ? 1 : 0;
- if (op.getNumOperands() !=
- memrefOperandCount + numDynamicDims + dynamicOffsetCount)
- return op.emitError("incorrect number of operands for type ") << viewType;
+ // Num offsets should either be zero or rank of memref.
+ if (op.getNumOffsets() != 0 && op.getNumOffsets() != subViewType.getRank()) {
+ return op.emitError("expected number of dynamic offsets specified to match "
+ "the rank of the result type ")
+ << subViewType;
+ }
- // Verify dynamic strides symbols were added to correct dimensions based
- // on dynamic sizes.
- if (failed(verifyDynamicStrides(viewType, strides)))
- return op.emitError("incorrect dynamic strides in view memref type ")
- << viewType;
+ // Num sizes should either be zero or rank of memref.
+ if (op.getNumSizes() != 0 && op.getNumSizes() != subViewType.getRank()) {
+ return op.emitError("expected number of dynamic sizes specified to match "
+ "the rank of the result type ")
+ << subViewType;
+ }
+
+ // Num strides should either be zero or rank of memref.
+ if (op.getNumStrides() != 0 && op.getNumStrides() != subViewType.getRank()) {
+ return op.emitError("expected number of dynamic strides specified to match "
+ "the rank of the result type ")
+ << subViewType;
+ }
+
+ // Verify that if the shape of the subview type is static, then sizes are not
+ // dynamic values, and vice versa.
+ if ((subViewType.hasStaticShape() && op.getNumSizes() != 0) ||
+ (op.getNumSizes() == 0 && !subViewType.hasStaticShape())) {
+ return op.emitError("invalid to specify dynamic sizes when subview result "
+ "type is statically shaped and viceversa");
+ }
+
+ // Verify that if dynamic sizes are specified, then the result memref type
+ // have full dynamic dimensions.
+ if (op.getNumSizes() > 0) {
+ if (llvm::any_of(subViewType.getShape(), [](int64_t dim) {
+ return dim != ShapedType::kDynamicSize;
+ })) {
+ // TODO: This is based on the assumption that number of size arguments are
+ // either 0, or the rank of the result type. It is possible to have more
+ // fine-grained verification where only particular dimensions are
+ // dynamic. That probably needs further changes to the shape op
+ // specification.
+ return op.emitError("expected shape of result type to be fully dynamic "
+ "when sizes are specified");
+ }
+ }
+
+ // Verify that if dynamic offsets are specified, or the base memref has a
+ // dynamic offset or dynamic strides, then the subview offset must be dynamic.
+ if ((op.getNumOffsets() > 0 ||
+ baseOffset == MemRefType::getDynamicStrideOrOffset() ||
+ llvm::is_contained(baseStrides,
+ MemRefType::getDynamicStrideOrOffset())) &&
+ subViewOffset != MemRefType::getDynamicStrideOrOffset()) {
+ return op.emitError(
+ "expected result memref layout map to have dynamic offset");
+ }
+
+ // For now, verify that if dynamic strides are specified, then all strides of
+ // the result memref type are dynamic.
+ if (op.getNumStrides() > 0) {
+ if (llvm::any_of(subViewStrides, [](int64_t stride) {
+ return stride != MemRefType::getDynamicStrideOrOffset();
+ })) {
+ return op.emitError("expected result type to have dynamic strides");
+ }
+ }
+
+ // If any stride of the base memref is dynamic, then the corresponding
+ // stride of the subview must also be dynamic.
+ assert(baseStrides.size() == subViewStrides.size());
+ for (auto stride : enumerate(baseStrides)) {
+ if (stride.value() == MemRefType::getDynamicStrideOrOffset() &&
+ subViewStrides[stride.index()] !=
+ MemRefType::getDynamicStrideOrOffset()) {
+ return op.emitError(
+ "expected result type to have dynamic stride along a dimension if "
+ "the base memref type has dynamic stride along that dimension");
+ }
+ }
return success();
}
-namespace {
-
-struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
- using OpRewritePattern<ViewOp>::OpRewritePattern;
+raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) {
+ return os << "range " << range.offset << ":" << range.size << ":"
+ << range.stride;
+}
- PatternMatchResult matchAndRewrite(ViewOp viewOp,
- PatternRewriter &rewriter) const override {
- // Return if none of the operands are constants.
- if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
- return matchPattern(operand, m_ConstantIndex());
- }))
- return matchFailure();
+SmallVector<SubViewOp::Range, 8> SubViewOp::getRanges() {
+ SmallVector<Range, 8> res;
+ unsigned rank = getType().getRank();
+ res.reserve(rank);
+ for (unsigned i = 0; i < rank; ++i)
+ res.emplace_back(Range{*(offsets().begin() + i), *(sizes().begin() + i),
+ *(strides().begin() + i)});
+ return res;
+}
- // Get result memref type.
- auto memrefType = viewOp.getType();
- if (memrefType.getAffineMaps().size() > 1)
- return matchFailure();
- auto map = memrefType.getAffineMaps().empty()
- ? AffineMap::getMultiDimIdentityMap(memrefType.getRank(),
- rewriter.getContext())
- : memrefType.getAffineMaps()[0];
+LogicalResult
+SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
+ // If dynamic strides are specified, return failure.
+ if (getNumStrides())
+ return failure();
- // Get offset from old memref view type 'memRefType'.
- int64_t oldOffset;
- SmallVector<int64_t, 4> oldStrides;
- if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
- return matchFailure();
+ // When static, the stride operands can be retrieved by taking the strides of
+ // the result of the subview op and dividing by the strides of the base memref.
+ int64_t resultOffset, baseOffset;
+ SmallVector<int64_t, 2> resultStrides, baseStrides;
+ if (failed(
+ getStridesAndOffset(getBaseMemRefType(), baseStrides, baseOffset)) ||
+ llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) ||
+ failed(getStridesAndOffset(getType(), resultStrides, resultOffset)))
+ return failure();
- SmallVector<Value, 4> newOperands;
+ assert(static_cast<int64_t>(resultStrides.size()) == getType().getRank() &&
+ baseStrides.size() == resultStrides.size() &&
+ "base and result memrefs must have the same rank");
+ assert(!llvm::is_contained(resultStrides,
+ MemRefType::getDynamicStrideOrOffset()) &&
+ "strides of subview op must be static, when there are no dynamic "
+ "strides specified");
+ staticStrides.resize(getType().getRank());
+ for (auto resultStride : enumerate(resultStrides)) {
+ auto baseStride = baseStrides[resultStride.index()];
+ // The result stride is expected to be a multiple of the base stride. Abort
+ // if that is not the case.
+ if (resultStride.value() < baseStride ||
+ resultStride.value() % baseStride != 0)
+ return failure();
+ staticStrides[resultStride.index()] = resultStride.value() / baseStride;
+ }
+ return success();
+}
- // Fold dynamic offset operand if it is produced by a constant.
- auto dynamicOffset = viewOp.getDynamicOffset();
- int64_t newOffset = oldOffset;
- unsigned dynamicOffsetOperandCount = 0;
- if (dynamicOffset != nullptr) {
- auto *defOp = dynamicOffset.getDefiningOp();
- if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
- // Dynamic offset will be folded into the map.
- newOffset = constantIndexOp.getValue();
- } else {
- // Unable to fold dynamic offset. Add it to 'newOperands' list.
- newOperands.push_back(dynamicOffset);
- dynamicOffsetOperandCount = 1;
- }
- }
+namespace {
- // Fold any dynamic dim operands which are produced by a constant.
- SmallVector<int64_t, 4> newShapeConstants;
- newShapeConstants.reserve(memrefType.getRank());
+/// Pattern to rewrite a subview op with constant size arguments.
+class SubViewOpShapeFolder final : public OpRewritePattern<SubViewOp> {
+public:
+ using OpRewritePattern<SubViewOp>::OpRewritePattern;
- unsigned dynamicDimPos = viewOp.getDynamicSizesOperandStart();
- unsigned rank = memrefType.getRank();
- for (unsigned dim = 0, e = rank; dim < e; ++dim) {
- int64_t dimSize = memrefType.getDimSize(dim);
- // If this is already static dimension, keep it.
- if (!ShapedType::isDynamic(dimSize)) {
- newShapeConstants.push_back(dimSize);
- continue;
- }
- auto *defOp = viewOp.getOperand(dynamicDimPos).getDefiningOp();
- if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
- // Dynamic shape dimension will be folded.
- newShapeConstants.push_back(constantIndexOp.getValue());
- } else {
- // Dynamic shape dimension not folded; copy operand from old memref.
- newShapeConstants.push_back(dimSize);
- newOperands.push_back(viewOp.getOperand(dynamicDimPos));
- }
- dynamicDimPos++;
+ PatternMatchResult matchAndRewrite(SubViewOp subViewOp,
+ PatternRewriter &rewriter) const override {
+ MemRefType subViewType = subViewOp.getType();
+ // Follow an all-or-nothing approach for shapes for now. If all the size
+ // operands are constants, fold them into the type of the result memref.
+ if (subViewType.hasStaticShape() ||
+ llvm::any_of(subViewOp.sizes(), [](Value operand) {
+ return !matchPattern(operand, m_ConstantIndex());
+ })) {
+ return matchFailure();
}
-
- // Compute new strides based on 'newShapeConstants'.
- SmallVector<int64_t, 4> newStrides(rank);
- newStrides[rank - 1] = 1;
- bool dynamicStrides = false;
- for (int i = rank - 2; i >= 0; --i) {
- if (ShapedType::isDynamic(newShapeConstants[i + 1]))
- dynamicStrides = true;
- if (dynamicStrides)
- newStrides[i] = MemRefType::getDynamicStrideOrOffset();
- else
- newStrides[i] = newShapeConstants[i + 1] * newStrides[i + 1];
+ SmallVector<int64_t, 4> staticShape(subViewOp.getNumSizes());
+ for (auto size : llvm::enumerate(subViewOp.sizes())) {
+ auto defOp = size.value().getDefiningOp();
+ assert(defOp);
+ staticShape[size.index()] = cast<ConstantIndexOp>(defOp).getValue();
}
+ MemRefType newMemRefType =
+ MemRefType::Builder(subViewType).setShape(staticShape);
+ auto newSubViewOp = rewriter.create<SubViewOp>(
+ subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
+ ArrayRef<Value>(), subViewOp.strides(), newMemRefType);
+ // Insert a memref_cast for compatibility of the uses of the op.
+ rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
+ subViewOp.getType());
+ return matchSuccess();
+ }
+};
- // Regenerate strided layout map with 'newStrides' and 'newOffset'.
- map = makeStridedLinearLayoutMap(newStrides, newOffset,
- rewriter.getContext());
+// Pattern to rewrite a subview op with constant stride arguments.
+class SubViewOpStrideFolder final : public OpRewritePattern<SubViewOp> {
+public:
+ using OpRewritePattern<SubViewOp>::OpRewritePattern;
- // Create new memref type with constant folded dims and/or offset/strides.
- MemRefType newMemRefType = MemRefType::Builder(memrefType)
- .setShape(newShapeConstants)
- .setAffineMaps({map});
- (void)dynamicOffsetOperandCount; // unused in opt mode
- assert(static_cast<int64_t>(newOperands.size()) ==
- dynamicOffsetOperandCount + newMemRefType.getNumDynamicDims());
+ PatternMatchResult matchAndRewrite(SubViewOp subViewOp,
+ PatternRewriter &rewriter) const override {
+ if (subViewOp.getNumStrides() == 0) {
+ return matchFailure();
+ }
+ // Follow an all-or-nothing approach for strides for now. If all the stride
+ // operands are constants, fold them into the strides of the result memref.
+ int64_t baseOffset, resultOffset;
+ SmallVector<int64_t, 4> baseStrides, resultStrides;
+ MemRefType subViewType = subViewOp.getType();
+ if (failed(getStridesAndOffset(subViewOp.getBaseMemRefType(), baseStrides,
+ baseOffset)) ||
+ failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) ||
+ llvm::is_contained(baseStrides,
+ MemRefType::getDynamicStrideOrOffset()) ||
+ llvm::any_of(subViewOp.strides(), [](Value stride) {
+ return !matchPattern(stride, m_ConstantIndex());
+ })) {
+ return matchFailure();
+ }
- // Create new ViewOp.
- auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
- viewOp.getOperand(0), newOperands);
- // Insert a cast so we have the same type as the old memref type.
- rewriter.replaceOpWithNewOp<MemRefCastOp>(viewOp, newViewOp,
- viewOp.getType());
+ SmallVector<int64_t, 4> staticStrides(subViewOp.getNumStrides());
+ for (auto stride : llvm::enumerate(subViewOp.strides())) {
+ auto defOp = stride.value().getDefiningOp();
+ assert(defOp);
+ assert(baseStrides[stride.index()] > 0);
+ staticStrides[stride.index()] =
+ cast<ConstantIndexOp>(defOp).getValue() * baseStrides[stride.index()];
+ }
+ AffineMap layoutMap = makeStridedLinearLayoutMap(
+ staticStrides, resultOffset, rewriter.getContext());
+ MemRefType newMemRefType =
+ MemRefType::Builder(subViewType).setAffineMaps(layoutMap);
+ auto newSubViewOp = rewriter.create<SubViewOp>(
+ subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
+ subViewOp.sizes(), ArrayRef<Value>(), newMemRefType);
+ // Insert a memref_cast for compatibility of the uses of the op.
+ rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
+ subViewOp.getType());
return matchSuccess();
}
};
-struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
- using OpRewritePattern<ViewOp>::OpRewritePattern;
+// Pattern to rewrite a subview op with constant offset arguments.
+class SubViewOpOffsetFolder final : public OpRewritePattern<SubViewOp> {
+public:
+ using OpRewritePattern<SubViewOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(ViewOp viewOp,
+ PatternMatchResult matchAndRewrite(SubViewOp subViewOp,
PatternRewriter &rewriter) const override {
- Value memrefOperand = viewOp.getOperand(0);
- MemRefCastOp memrefCastOp =
- dyn_cast_or_null<MemRefCastOp>(memrefOperand.getDefiningOp());
- if (!memrefCastOp)
+ if (subViewOp.getNumOffsets() == 0) {
return matchFailure();
- Value allocOperand = memrefCastOp.getOperand();
- AllocOp allocOp = dyn_cast_or_null<AllocOp>(allocOperand.getDefiningOp());
- if (!allocOp)
+ }
+ // Follow an all-or-nothing approach for offsets for now. If all the offset
+ // operands are constants, fold them into the offset of the result memref.
+ int64_t baseOffset, resultOffset;
+ SmallVector<int64_t, 4> baseStrides, resultStrides;
+ MemRefType subViewType = subViewOp.getType();
+ if (failed(getStridesAndOffset(subViewOp.getBaseMemRefType(), baseStrides,
+ baseOffset)) ||
+ failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) ||
+ llvm::is_contained(baseStrides,
+ MemRefType::getDynamicStrideOrOffset()) ||
+ baseOffset == MemRefType::getDynamicStrideOrOffset() ||
+ llvm::any_of(subViewOp.offsets(), [](Value stride) {
+ return !matchPattern(stride, m_ConstantIndex());
+ })) {
return matchFailure();
- rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
- viewOp.operands());
+ }
+
+ auto staticOffset = baseOffset;
+ for (auto offset : llvm::enumerate(subViewOp.offsets())) {
+ auto defOp = offset.value().getDefiningOp();
+ assert(defOp);
+ assert(baseStrides[offset.index()] > 0);
+ staticOffset +=
+ cast<ConstantIndexOp>(defOp).getValue() * baseStrides[offset.index()];
+ }
+
+ AffineMap layoutMap = makeStridedLinearLayoutMap(
+ resultStrides, staticOffset, rewriter.getContext());
+ MemRefType newMemRefType =
+ MemRefType::Builder(subViewType).setAffineMaps(layoutMap);
+ auto newSubViewOp = rewriter.create<SubViewOp>(
+ subViewOp.getLoc(), subViewOp.source(), ArrayRef<Value>(),
+ subViewOp.sizes(), subViewOp.strides(), newMemRefType);
+ // Insert a memref_cast for compatibility of the uses of the op.
+ rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
+ subViewOp.getType());
return matchSuccess();
}
};
} // end anonymous namespace
-void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
- results.insert<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
+void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<SubViewOpShapeFolder, SubViewOpStrideFolder,
+ SubViewOpOffsetFolder>(context);
}
//===----------------------------------------------------------------------===//
-// SubViewOp
+// TensorCastOp
//===----------------------------------------------------------------------===//
-// Returns a MemRefType with dynamic sizes and offset and the same stride as the
-// `memRefType` passed as argument.
-// TODO(andydavis,ntv) Evolve to a more powerful inference that can also keep
-// sizes and offset static.
-static Type inferSubViewResultType(MemRefType memRefType) {
- auto rank = memRefType.getRank();
- int64_t offset;
- SmallVector<int64_t, 4> strides;
- auto res = getStridesAndOffset(memRefType, strides, offset);
- assert(succeeded(res) && "SubViewOp expected strided memref type");
- (void)res;
-
- // Assume sizes and offset are fully dynamic for now until canonicalization
- // occurs on the ranges. Typed strides don't change though.
- offset = MemRefType::getDynamicStrideOrOffset();
- // Overwrite strides because verifier will not pass.
- // TODO(b/144419106): don't force degrade the strides to fully dynamic.
- for (auto &stride : strides)
- stride = MemRefType::getDynamicStrideOrOffset();
- auto stridedLayout =
- makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
- SmallVector<int64_t, 4> sizes(rank, ShapedType::kDynamicSize);
- return MemRefType::Builder(memRefType)
- .setShape(sizes)
- .setAffineMaps(stridedLayout);
-}
+bool TensorCastOp::areCastCompatible(Type a, Type b) {
+ auto aT = a.dyn_cast<TensorType>();
+ auto bT = b.dyn_cast<TensorType>();
+ if (!aT || !bT)
+ return false;
-void mlir::SubViewOp::build(Builder *b, OperationState &result, Value source,
- ValueRange offsets, ValueRange sizes,
- ValueRange strides, Type resultType,
- ArrayRef<NamedAttribute> attrs) {
- if (!resultType)
- resultType = inferSubViewResultType(source.getType().cast<MemRefType>());
- auto segmentAttr = b->getI32VectorAttr(
- {1, static_cast<int>(offsets.size()), static_cast<int32_t>(sizes.size()),
- static_cast<int32_t>(strides.size())});
- build(b, result, resultType, source, offsets, sizes, strides, segmentAttr);
- result.addAttributes(attrs);
-}
+ if (aT.getElementType() != bT.getElementType())
+ return false;
-void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType,
- Value source) {
- build(b, result, source, /*offsets=*/{}, /*sizes=*/{}, /*strides=*/{},
- resultType);
+ return succeeded(verifyCompatibleShape(aT, bT));
}
-static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
- OpAsmParser::OperandType srcInfo;
- SmallVector<OpAsmParser::OperandType, 4> offsetsInfo;
- SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
- SmallVector<OpAsmParser::OperandType, 4> stridesInfo;
- auto indexType = parser.getBuilder().getIndexType();
- Type srcType, dstType;
- if (parser.parseOperand(srcInfo) ||
- parser.parseOperandList(offsetsInfo, OpAsmParser::Delimiter::Square) ||
- parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
- parser.parseOperandList(stridesInfo, OpAsmParser::Delimiter::Square)) {
- return failure();
- }
-
- auto builder = parser.getBuilder();
- result.addAttribute(
- SubViewOp::getOperandSegmentSizeAttr(),
- builder.getI32VectorAttr({1, static_cast<int>(offsetsInfo.size()),
- static_cast<int32_t>(sizesInfo.size()),
- static_cast<int32_t>(stridesInfo.size())}));
-
- return failure(
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(srcType) ||
- parser.resolveOperand(srcInfo, srcType, result.operands) ||
- parser.resolveOperands(offsetsInfo, indexType, result.operands) ||
- parser.resolveOperands(sizesInfo, indexType, result.operands) ||
- parser.resolveOperands(stridesInfo, indexType, result.operands) ||
- parser.parseKeywordType("to", dstType) ||
- parser.addTypeToList(dstType, result.types));
+OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
+ return impl::foldCastOp(*this);
}
-static void print(OpAsmPrinter &p, SubViewOp op) {
- p << op.getOperationName() << ' ' << op.getOperand(0) << '[' << op.offsets()
- << "][" << op.sizes() << "][" << op.strides() << ']';
+//===----------------------------------------------------------------------===//
+// Helpers for Tensor[Load|Store]Op
+//===----------------------------------------------------------------------===//
- std::array<StringRef, 1> elidedAttrs = {
- SubViewOp::getOperandSegmentSizeAttr()};
- p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
- p << " : " << op.getOperand(0).getType() << " to " << op.getType();
+static Type getTensorTypeFromMemRefType(Type type) {
+ if (auto memref = type.dyn_cast<MemRefType>())
+ return RankedTensorType::get(memref.getShape(), memref.getElementType());
+ return NoneType::get(type.getContext());
}
-static LogicalResult verify(SubViewOp op) {
- auto baseType = op.getBaseMemRefType().cast<MemRefType>();
- auto subViewType = op.getType();
-
- // The rank of the base and result subview must match.
- if (baseType.getRank() != subViewType.getRank()) {
- return op.emitError(
- "expected rank of result type to match rank of base type ");
- }
-
- // The base memref and the view memref should be in the same memory space.
- if (baseType.getMemorySpace() != subViewType.getMemorySpace())
- return op.emitError("
diff erent memory spaces specified for base memref "
- "type ")
- << baseType << " and subview memref type " << subViewType;
-
- // Verify that the base memref type has a strided layout map.
- int64_t baseOffset;
- SmallVector<int64_t, 4> baseStrides;
- if (failed(getStridesAndOffset(baseType, baseStrides, baseOffset)))
- return op.emitError("base type ") << subViewType << " is not strided";
-
- // Verify that the result memref type has a strided layout map.
- int64_t subViewOffset;
- SmallVector<int64_t, 4> subViewStrides;
- if (failed(getStridesAndOffset(subViewType, subViewStrides, subViewOffset)))
- return op.emitError("result type ") << subViewType << " is not strided";
-
- // Num offsets should either be zero or rank of memref.
- if (op.getNumOffsets() != 0 && op.getNumOffsets() != subViewType.getRank()) {
- return op.emitError("expected number of dynamic offsets specified to match "
- "the rank of the result type ")
- << subViewType;
- }
-
- // Num sizes should either be zero or rank of memref.
- if (op.getNumSizes() != 0 && op.getNumSizes() != subViewType.getRank()) {
- return op.emitError("expected number of dynamic sizes specified to match "
- "the rank of the result type ")
- << subViewType;
- }
-
- // Num strides should either be zero or rank of memref.
- if (op.getNumStrides() != 0 && op.getNumStrides() != subViewType.getRank()) {
- return op.emitError("expected number of dynamic strides specified to match "
- "the rank of the result type ")
- << subViewType;
- }
-
- // Verify that if the shape of the subview type is static, then sizes are not
- // dynamic values, and vice versa.
- if ((subViewType.hasStaticShape() && op.getNumSizes() != 0) ||
- (op.getNumSizes() == 0 && !subViewType.hasStaticShape())) {
- return op.emitError("invalid to specify dynamic sizes when subview result "
- "type is statically shaped and viceversa");
- }
+//===----------------------------------------------------------------------===//
+// TruncateIOp
+//===----------------------------------------------------------------------===//
- // Verify that if dynamic sizes are specified, then the result memref type
- // have full dynamic dimensions.
- if (op.getNumSizes() > 0) {
- if (llvm::any_of(subViewType.getShape(), [](int64_t dim) {
- return dim != ShapedType::kDynamicSize;
- })) {
- // TODO: This is based on the assumption that number of size arguments are
- // either 0, or the rank of the result type. It is possible to have more
- // fine-grained verification where only particular dimensions are
- // dynamic. That probably needs further changes to the shape op
- // specification.
- return op.emitError("expected shape of result type to be fully dynamic "
- "when sizes are specified");
- }
- }
+static LogicalResult verify(TruncateIOp op) {
+ auto srcType = getElementTypeOrSelf(op.getOperand().getType());
+ auto dstType = getElementTypeOrSelf(op.getType());
- // Verify that if dynamic offsets are specified or base memref has dynamic
- // offset or base memref has dynamic strides, then the subview offset is
- // dynamic.
- if ((op.getNumOffsets() > 0 ||
- baseOffset == MemRefType::getDynamicStrideOrOffset() ||
- llvm::is_contained(baseStrides,
- MemRefType::getDynamicStrideOrOffset())) &&
- subViewOffset != MemRefType::getDynamicStrideOrOffset()) {
- return op.emitError(
- "expected result memref layout map to have dynamic offset");
- }
+ if (srcType.isa<IndexType>())
+ return op.emitError() << srcType << " is not a valid operand type";
+ if (dstType.isa<IndexType>())
+ return op.emitError() << dstType << " is not a valid result type";
- // For now, verify that if dynamic strides are specified, then all the result
- // memref type have dynamic strides.
- if (op.getNumStrides() > 0) {
- if (llvm::any_of(subViewStrides, [](int64_t stride) {
- return stride != MemRefType::getDynamicStrideOrOffset();
- })) {
- return op.emitError("expected result type to have dynamic strides");
- }
- }
+ if (srcType.cast<IntegerType>().getWidth() <=
+ dstType.cast<IntegerType>().getWidth())
+ return op.emitError("operand type ")
+ << srcType << " must be wider than result type " << dstType;
- // If any of the base memref has dynamic stride, then the corresponding
- // stride of the subview must also have dynamic stride.
- assert(baseStrides.size() == subViewStrides.size());
- for (auto stride : enumerate(baseStrides)) {
- if (stride.value() == MemRefType::getDynamicStrideOrOffset() &&
- subViewStrides[stride.index()] !=
- MemRefType::getDynamicStrideOrOffset()) {
- return op.emitError(
- "expected result type to have dynamic stride along a dimension if "
- "the base memref type has dynamic stride along that dimension");
- }
- }
return success();
}
-raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) {
- return os << "range " << range.offset << ":" << range.size << ":"
- << range.stride;
+//===----------------------------------------------------------------------===//
+// UnsignedDivIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult UnsignedDivIOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 2 && "binary operation takes two operands");
+
+ // Don't fold if it would require a division by zero.
+ bool div0 = false;
+ auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
+ if (div0 || !b) {
+ div0 = true;
+ return a;
+ }
+ return a.udiv(b);
+ });
+ return div0 ? Attribute() : result;
}
-SmallVector<SubViewOp::Range, 8> SubViewOp::getRanges() {
- SmallVector<Range, 8> res;
- unsigned rank = getType().getRank();
- res.reserve(rank);
- for (unsigned i = 0; i < rank; ++i)
- res.emplace_back(Range{*(offsets().begin() + i), *(sizes().begin() + i),
- *(strides().begin() + i)});
- return res;
+//===----------------------------------------------------------------------===//
+// UnsignedRemIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult UnsignedRemIOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 2 && "remi_unsigned takes two operands");
+
+ auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
+ if (!rhs)
+ return {};
+ auto rhsValue = rhs.getValue();
+
+ // x % 1 = 0
+ if (rhsValue.isOneValue())
+ return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
+
+ // Don't fold if it requires division by zero.
+ if (rhsValue.isNullValue())
+ return {};
+
+ auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
+ if (!lhs)
+ return {};
+ return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
}
-LogicalResult
-SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
- // If the strides are dynamic return failure.
- if (getNumStrides())
- return failure();
+//===----------------------------------------------------------------------===//
+// ViewOp
+//===----------------------------------------------------------------------===//
- // When static, the stride operands can be retrieved by taking the strides of
- // the result of the subview op, and dividing the strides of the base memref.
- int64_t resultOffset, baseOffset;
- SmallVector<int64_t, 2> resultStrides, baseStrides;
- if (failed(
- getStridesAndOffset(getBaseMemRefType(), baseStrides, baseOffset)) ||
- llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) ||
- failed(getStridesAndOffset(getType(), resultStrides, resultOffset)))
+static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::OperandType srcInfo;
+ SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
+ SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
+ auto indexType = parser.getBuilder().getIndexType();
+ Type srcType, dstType;
+ llvm::SMLoc offsetLoc;
+ if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
+ parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
return failure();
- assert(static_cast<int64_t>(resultStrides.size()) == getType().getRank() &&
- baseStrides.size() == resultStrides.size() &&
- "base and result memrefs must have the same rank");
- assert(!llvm::is_contained(resultStrides,
- MemRefType::getDynamicStrideOrOffset()) &&
- "strides of subview op must be static, when there are no dynamic "
- "strides specified");
- staticStrides.resize(getType().getRank());
- for (auto resultStride : enumerate(resultStrides)) {
- auto baseStride = baseStrides[resultStride.index()];
- // The result stride is expected to be a multiple of the base stride. Abort
- // if that is not the case.
- if (resultStride.value() < baseStride ||
- resultStride.value() % baseStride != 0)
+ if (offsetInfo.size() > 1)
+ return parser.emitError(offsetLoc) << "expects 0 or 1 offset operand";
+
+ return failure(
+ parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(srcType) ||
+ parser.resolveOperand(srcInfo, srcType, result.operands) ||
+ parser.resolveOperands(offsetInfo, indexType, result.operands) ||
+ parser.resolveOperands(sizesInfo, indexType, result.operands) ||
+ parser.parseKeywordType("to", dstType) ||
+ parser.addTypeToList(dstType, result.types));
+}
+
+static void print(OpAsmPrinter &p, ViewOp op) {
+ p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
+ auto dynamicOffset = op.getDynamicOffset();
+ if (dynamicOffset != nullptr)
+ p.printOperand(dynamicOffset);
+ p << "][" << op.getDynamicSizes() << ']';
+ p.printOptionalAttrDict(op.getAttrs());
+ p << " : " << op.getOperand(0).getType() << " to " << op.getType();
+}
+
+Value ViewOp::getDynamicOffset() {
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ auto result =
+ succeeded(mlir::getStridesAndOffset(getType(), strides, offset));
+ assert(result);
+ if (result && offset == MemRefType::getDynamicStrideOrOffset())
+ return getOperand(1);
+ return nullptr;
+}
+
+static LogicalResult verifyDynamicStrides(MemRefType memrefType,
+ ArrayRef<int64_t> strides) {
+ ArrayRef<int64_t> shape = memrefType.getShape();
+ unsigned rank = memrefType.getRank();
+ assert(rank == strides.size());
+ bool dynamicStrides = false;
+ for (int i = rank - 2; i >= 0; --i) {
+ // If size at dim 'i + 1' is dynamic, set the 'dynamicStrides' flag.
+ if (ShapedType::isDynamic(shape[i + 1]))
+ dynamicStrides = true;
+ // If stride at dim 'i' is not dynamic, return error.
+ if (dynamicStrides && strides[i] != MemRefType::getDynamicStrideOrOffset())
return failure();
- staticStrides[resultStride.index()] = resultStride.value() / baseStride;
}
return success();
}
-//===----------------------------------------------------------------------===//
-// AssumeAlignmentOp
-//===----------------------------------------------------------------------===//
+static LogicalResult verify(ViewOp op) {
+ auto baseType = op.getOperand(0).getType().cast<MemRefType>();
+ auto viewType = op.getResult().getType().cast<MemRefType>();
-static LogicalResult verify(AssumeAlignmentOp op) {
- unsigned alignment = op.alignment().getZExtValue();
- if (!llvm::isPowerOf2_32(alignment))
- return op.emitOpError("alignment must be power of 2");
+ // The base memref should have identity layout map (or none).
+ if (baseType.getAffineMaps().size() > 1 ||
+ (baseType.getAffineMaps().size() == 1 &&
+ !baseType.getAffineMaps()[0].isIdentity()))
+ return op.emitError("unsupported map for base memref type ") << baseType;
+
+ // The base memref and the view memref should be in the same memory space.
+ if (baseType.getMemorySpace() != viewType.getMemorySpace())
+ return op.emitError("
diff erent memory spaces specified for base memref "
+ "type ")
+ << baseType << " and view memref type " << viewType;
+
+ // Verify that the result memref type has a strided layout map.
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ if (failed(getStridesAndOffset(viewType, strides, offset)))
+ return op.emitError("result type ") << viewType << " is not strided";
+
+ // Verify that we have the correct number of operands for the result type.
+ unsigned memrefOperandCount = 1;
+ unsigned numDynamicDims = viewType.getNumDynamicDims();
+ unsigned dynamicOffsetCount =
+ offset == MemRefType::getDynamicStrideOrOffset() ? 1 : 0;
+ if (op.getNumOperands() !=
+ memrefOperandCount + numDynamicDims + dynamicOffsetCount)
+ return op.emitError("incorrect number of operands for type ") << viewType;
+
+ // Verify dynamic strides symbols were added to correct dimensions based
+ // on dynamic sizes.
+ if (failed(verifyDynamicStrides(viewType, strides)))
+ return op.emitError("incorrect dynamic strides in view memref type ")
+ << viewType;
return success();
}
namespace {
-/// Pattern to rewrite a subview op with constant size arguments.
-class SubViewOpShapeFolder final : public OpRewritePattern<SubViewOp> {
-public:
- using OpRewritePattern<SubViewOp>::OpRewritePattern;
+struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
+ using OpRewritePattern<ViewOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(SubViewOp subViewOp,
+ PatternMatchResult matchAndRewrite(ViewOp viewOp,
PatternRewriter &rewriter) const override {
- MemRefType subViewType = subViewOp.getType();
- // Follow all or nothing approach for shapes for now. If all the operands
- // for sizes are constants then fold it into the type of the result memref.
- if (subViewType.hasStaticShape() ||
- llvm::any_of(subViewOp.sizes(), [](Value operand) {
- return !matchPattern(operand, m_ConstantIndex());
- })) {
+ // Return if none of the operands are constants.
+ if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
+ return matchPattern(operand, m_ConstantIndex());
+ }))
return matchFailure();
- }
- SmallVector<int64_t, 4> staticShape(subViewOp.getNumSizes());
- for (auto size : llvm::enumerate(subViewOp.sizes())) {
- auto defOp = size.value().getDefiningOp();
- assert(defOp);
- staticShape[size.index()] = cast<ConstantIndexOp>(defOp).getValue();
- }
- MemRefType newMemRefType =
- MemRefType::Builder(subViewType).setShape(staticShape);
- auto newSubViewOp = rewriter.create<SubViewOp>(
- subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
- ArrayRef<Value>(), subViewOp.strides(), newMemRefType);
- // Insert a memref_cast for compatibility of the uses of the op.
- rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
- subViewOp.getType());
- return matchSuccess();
- }
-};
-// Pattern to rewrite a subview op with constant stride arguments.
-class SubViewOpStrideFolder final : public OpRewritePattern<SubViewOp> {
-public:
- using OpRewritePattern<SubViewOp>::OpRewritePattern;
+ // Get result memref type.
+ auto memrefType = viewOp.getType();
+ if (memrefType.getAffineMaps().size() > 1)
+ return matchFailure();
+ auto map = memrefType.getAffineMaps().empty()
+ ? AffineMap::getMultiDimIdentityMap(memrefType.getRank(),
+ rewriter.getContext())
+ : memrefType.getAffineMaps()[0];
- PatternMatchResult matchAndRewrite(SubViewOp subViewOp,
- PatternRewriter &rewriter) const override {
- if (subViewOp.getNumStrides() == 0) {
+ // Get offset from old memref view type 'memRefType'.
+ int64_t oldOffset;
+ SmallVector<int64_t, 4> oldStrides;
+ if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
return matchFailure();
+
+ SmallVector<Value, 4> newOperands;
+
+ // Fold dynamic offset operand if it is produced by a constant.
+ auto dynamicOffset = viewOp.getDynamicOffset();
+ int64_t newOffset = oldOffset;
+ unsigned dynamicOffsetOperandCount = 0;
+ if (dynamicOffset != nullptr) {
+ auto *defOp = dynamicOffset.getDefiningOp();
+ if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
+ // Dynamic offset will be folded into the map.
+ newOffset = constantIndexOp.getValue();
+ } else {
+ // Unable to fold dynamic offset. Add it to 'newOperands' list.
+ newOperands.push_back(dynamicOffset);
+ dynamicOffsetOperandCount = 1;
+ }
}
- // Follow all or nothing approach for strides for now. If all the operands
- // for strides are constants then fold it into the strides of the result
- // memref.
- int64_t baseOffset, resultOffset;
- SmallVector<int64_t, 4> baseStrides, resultStrides;
- MemRefType subViewType = subViewOp.getType();
- if (failed(getStridesAndOffset(subViewOp.getBaseMemRefType(), baseStrides,
- baseOffset)) ||
- failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) ||
- llvm::is_contained(baseStrides,
- MemRefType::getDynamicStrideOrOffset()) ||
- llvm::any_of(subViewOp.strides(), [](Value stride) {
- return !matchPattern(stride, m_ConstantIndex());
- })) {
- return matchFailure();
+
+ // Fold any dynamic dim operands which are produced by a constant.
+ SmallVector<int64_t, 4> newShapeConstants;
+ newShapeConstants.reserve(memrefType.getRank());
+
+ unsigned dynamicDimPos = viewOp.getDynamicSizesOperandStart();
+ unsigned rank = memrefType.getRank();
+ for (unsigned dim = 0, e = rank; dim < e; ++dim) {
+ int64_t dimSize = memrefType.getDimSize(dim);
+ // If this is already static dimension, keep it.
+ if (!ShapedType::isDynamic(dimSize)) {
+ newShapeConstants.push_back(dimSize);
+ continue;
+ }
+ auto *defOp = viewOp.getOperand(dynamicDimPos).getDefiningOp();
+ if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
+ // Dynamic shape dimension will be folded.
+ newShapeConstants.push_back(constantIndexOp.getValue());
+ } else {
+ // Dynamic shape dimension not folded; copy operand from old memref.
+ newShapeConstants.push_back(dimSize);
+ newOperands.push_back(viewOp.getOperand(dynamicDimPos));
+ }
+ dynamicDimPos++;
}
- SmallVector<int64_t, 4> staticStrides(subViewOp.getNumStrides());
- for (auto stride : llvm::enumerate(subViewOp.strides())) {
- auto defOp = stride.value().getDefiningOp();
- assert(defOp);
- assert(baseStrides[stride.index()] > 0);
- staticStrides[stride.index()] =
- cast<ConstantIndexOp>(defOp).getValue() * baseStrides[stride.index()];
+ // Compute new strides based on 'newShapeConstants'.
+ SmallVector<int64_t, 4> newStrides(rank);
+ newStrides[rank - 1] = 1;
+ bool dynamicStrides = false;
+ for (int i = rank - 2; i >= 0; --i) {
+ if (ShapedType::isDynamic(newShapeConstants[i + 1]))
+ dynamicStrides = true;
+ if (dynamicStrides)
+ newStrides[i] = MemRefType::getDynamicStrideOrOffset();
+ else
+ newStrides[i] = newShapeConstants[i + 1] * newStrides[i + 1];
}
- AffineMap layoutMap = makeStridedLinearLayoutMap(
- staticStrides, resultOffset, rewriter.getContext());
- MemRefType newMemRefType =
- MemRefType::Builder(subViewType).setAffineMaps(layoutMap);
- auto newSubViewOp = rewriter.create<SubViewOp>(
- subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
- subViewOp.sizes(), ArrayRef<Value>(), newMemRefType);
- // Insert a memref_cast for compatibility of the uses of the op.
- rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
- subViewOp.getType());
+
+ // Regenerate strided layout map with 'newStrides' and 'newOffset'.
+ map = makeStridedLinearLayoutMap(newStrides, newOffset,
+ rewriter.getContext());
+
+ // Create new memref type with constant folded dims and/or offset/strides.
+ MemRefType newMemRefType = MemRefType::Builder(memrefType)
+ .setShape(newShapeConstants)
+ .setAffineMaps({map});
+ (void)dynamicOffsetOperandCount; // unused in opt mode
+ assert(static_cast<int64_t>(newOperands.size()) ==
+ dynamicOffsetOperandCount + newMemRefType.getNumDynamicDims());
+
+ // Create new ViewOp.
+ auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
+ viewOp.getOperand(0), newOperands);
+ // Insert a cast so we have the same type as the old memref type.
+ rewriter.replaceOpWithNewOp<MemRefCastOp>(viewOp, newViewOp,
+ viewOp.getType());
return matchSuccess();
}
};
-// Pattern to rewrite a subview op with constant offset arguments.
-class SubViewOpOffsetFolder final : public OpRewritePattern<SubViewOp> {
-public:
- using OpRewritePattern<SubViewOp>::OpRewritePattern;
+struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
+ using OpRewritePattern<ViewOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(SubViewOp subViewOp,
+ PatternMatchResult matchAndRewrite(ViewOp viewOp,
PatternRewriter &rewriter) const override {
- if (subViewOp.getNumOffsets() == 0) {
+ Value memrefOperand = viewOp.getOperand(0);
+ MemRefCastOp memrefCastOp =
+ dyn_cast_or_null<MemRefCastOp>(memrefOperand.getDefiningOp());
+ if (!memrefCastOp)
return matchFailure();
- }
- // Follow all or nothing approach for offsets for now. If all the operands
- // for offsets are constants then fold it into the offset of the result
- // memref.
- int64_t baseOffset, resultOffset;
- SmallVector<int64_t, 4> baseStrides, resultStrides;
- MemRefType subViewType = subViewOp.getType();
- if (failed(getStridesAndOffset(subViewOp.getBaseMemRefType(), baseStrides,
- baseOffset)) ||
- failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) ||
- llvm::is_contained(baseStrides,
- MemRefType::getDynamicStrideOrOffset()) ||
- baseOffset == MemRefType::getDynamicStrideOrOffset() ||
- llvm::any_of(subViewOp.offsets(), [](Value stride) {
- return !matchPattern(stride, m_ConstantIndex());
- })) {
+ Value allocOperand = memrefCastOp.getOperand();
+ AllocOp allocOp = dyn_cast_or_null<AllocOp>(allocOperand.getDefiningOp());
+ if (!allocOp)
return matchFailure();
- }
-
- auto staticOffset = baseOffset;
- for (auto offset : llvm::enumerate(subViewOp.offsets())) {
- auto defOp = offset.value().getDefiningOp();
- assert(defOp);
- assert(baseStrides[offset.index()] > 0);
- staticOffset +=
- cast<ConstantIndexOp>(defOp).getValue() * baseStrides[offset.index()];
- }
-
- AffineMap layoutMap = makeStridedLinearLayoutMap(
- resultStrides, staticOffset, rewriter.getContext());
- MemRefType newMemRefType =
- MemRefType::Builder(subViewType).setAffineMaps(layoutMap);
- auto newSubViewOp = rewriter.create<SubViewOp>(
- subViewOp.getLoc(), subViewOp.source(), ArrayRef<Value>(),
- subViewOp.sizes(), subViewOp.strides(), newMemRefType);
- // Insert a memref_cast for compatibility of the uses of the op.
- rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
- subViewOp.getType());
+ rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
+ viewOp.operands());
return matchSuccess();
}
};
} // end anonymous namespace
-void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
- results.insert<SubViewOpShapeFolder, SubViewOpStrideFolder,
- SubViewOpOffsetFolder>(context);
+void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// XOrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
+ /// xor(x, 0) -> x
+ if (matchPattern(rhs(), m_Zero()))
+ return lhs();
+ /// xor(x,x) -> 0
+ if (lhs() == rhs())
+ return Builder(getContext()).getZeroAttr(getType());
+
+ return constFoldBinaryOp<IntegerAttr>(operands,
+ [](APInt a, APInt b) { return a ^ b; });
}
//===----------------------------------------------------------------------===//
@@ -2479,71 +2544,6 @@ static LogicalResult verify(ZeroExtendIOp op) {
return success();
}
-//===----------------------------------------------------------------------===//
-// FPExtOp
-//===----------------------------------------------------------------------===//
-
-bool FPExtOp::areCastCompatible(Type a, Type b) {
- if (auto fa = a.dyn_cast<FloatType>())
- if (auto fb = b.dyn_cast<FloatType>())
- return fa.getWidth() < fb.getWidth();
- if (auto va = a.dyn_cast<VectorType>())
- if (auto vb = b.dyn_cast<VectorType>())
- return va.getShape().equals(vb.getShape()) &&
- areCastCompatible(va.getElementType(), vb.getElementType());
- return false;
-}
-
-//===----------------------------------------------------------------------===//
-// FPTruncOp
-//===----------------------------------------------------------------------===//
-
-bool FPTruncOp::areCastCompatible(Type a, Type b) {
- if (auto fa = a.dyn_cast<FloatType>())
- if (auto fb = b.dyn_cast<FloatType>())
- return fa.getWidth() > fb.getWidth();
- if (auto va = a.dyn_cast<VectorType>())
- if (auto vb = b.dyn_cast<VectorType>())
- return va.getShape().equals(vb.getShape()) &&
- areCastCompatible(va.getElementType(), vb.getElementType());
- return false;
-}
-
-//===----------------------------------------------------------------------===//
-// AtomicRMWOp
-//===----------------------------------------------------------------------===//
-
-static LogicalResult verify(AtomicRMWOp op) {
- if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
- return op.emitOpError(
- "expects the number of subscripts to be equal to memref rank");
- switch (op.kind()) {
- case AtomicRMWKind::addf:
- case AtomicRMWKind::maxf:
- case AtomicRMWKind::minf:
- case AtomicRMWKind::mulf:
- if (!op.value().getType().isa<FloatType>())
- return op.emitOpError()
- << "with kind '" << stringifyAtomicRMWKind(op.kind())
- << "' expects a floating-point type";
- break;
- case AtomicRMWKind::addi:
- case AtomicRMWKind::maxs:
- case AtomicRMWKind::maxu:
- case AtomicRMWKind::mins:
- case AtomicRMWKind::minu:
- case AtomicRMWKind::muli:
- if (!op.value().getType().isa<IntegerType>())
- return op.emitOpError()
- << "with kind '" << stringifyAtomicRMWKind(op.kind())
- << "' expects an integer type";
- break;
- default:
- break;
- }
- return success();
-}
-
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//