[Mlir-commits] [mlir] b006902 - [mlir] Fold trivial subtensor / subtensor_insert ops.
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Feb 18 13:41:54 PST 2021
Author: Nicolas Vasilache
Date: 2021-02-18T21:34:55Z
New Revision: b006902b2dfac792e8ade73798ca1b216654faf7
URL: https://github.com/llvm/llvm-project/commit/b006902b2dfac792e8ade73798ca1b216654faf7
DIFF: https://github.com/llvm/llvm-project/commit/b006902b2dfac792e8ade73798ca1b216654faf7.diff
LOG: [mlir] Fold trivial subtensor / subtensor_insert ops.
Static subtensor / subtensor_insert of the same size as the source / destination tensor and root @[0..0] with strides [1..1] are folded away.
Differential revision: https://reviews.llvm.org/D96991
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index a7142d298b66..9a253f8e814a 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -364,6 +364,12 @@ bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
/// comparison predicates.
bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
const APFloat &rhs);
+
+/// Return true if ofr1 and ofr2 are the same integer constant attribute values
+/// or the same SSA value.
+/// Ignore integer bitwidth and type mismatch that come from the fact there is
+/// no IndexAttr and that IndexType has no bitwidth.
+bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
} // end namespace mlir
#endif // MLIR_DIALECT_IR_STANDARDOPS_IR_OPS_H
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 29863c82c502..64279c8fce3c 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2928,6 +2928,7 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<
}];
let hasCanonicalizer = 1;
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
@@ -3026,6 +3027,7 @@ def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<
/// and `strides` operands.
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }
}];
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 52b41ca305d1..e916de4c0658 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -59,6 +59,27 @@ static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
}
+/// Return true if ofr1 and ofr2 are the same integer constant attribute values
+/// or the same SSA value.
+/// Ignore integer bitwidth and type mismatch that come from the fact there is
+/// no IndexAttr and that IndexType has no bitwidth.
+bool mlir::isEqualConstantIntOrValue(OpFoldResult op1, OpFoldResult op2) {
+ auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
+ Attribute attr = ofr.dyn_cast<Attribute>();
+ // Note: isa+cast-like pattern allows writing the condition below as 1 line.
+ if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>())
+ attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue();
+ if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
+ return intAttr.getValue().getSExtValue();
+ return llvm::None;
+ };
+ auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
+ if (cst1 && cst2 && *cst1 == *cst2)
+ return true;
+ auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
+ return v1 && v2 && v1 == v2;
+}
+
//===----------------------------------------------------------------------===//
// StandardOpsDialect Interfaces
//===----------------------------------------------------------------------===//
@@ -3557,6 +3578,34 @@ void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
context);
}
+/// Return success if `op` describes an identity view: every offset is the
+/// constant 0, every stride is the constant 1, and the sizes match the
+/// leading dimensions of `shapedType`. Such an op is a no-op and can be
+/// folded away by the callers below.
+static LogicalResult
+foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
+ ShapedType shapedType) {
+ OpBuilder b(op.getContext());
+ for (OpFoldResult ofr : op.getMixedOffsets())
+ if (!isEqualConstantIntOrValue(ofr, b.getIndexAttr(0)))
+ return failure();
+ // Rank-reducing noops only need to inspect the leading dimensions: llvm::zip
+ // is appropriate.
+ auto shape = shapedType.getShape();
+ for (auto it : llvm::zip(op.getMixedSizes(), shape))
+ if (!isEqualConstantIntOrValue(std::get<0>(it),
+ b.getIndexAttr(std::get<1>(it))))
+ return failure();
+ for (OpFoldResult ofr : op.getMixedStrides())
+ if (!isEqualConstantIntOrValue(ofr, b.getIndexAttr(1)))
+ return failure();
+ return success();
+}
+
+/// Fold a subtensor whose result type equals its source type and whose
+/// offsets/sizes/strides form an identity view to the source tensor itself.
+OpFoldResult SubTensorOp::fold(ArrayRef<Attribute>) {
+ if (getSourceType() == getType() &&
+ succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
+ return this->source();
+ return OpFoldResult();
+}
+
//===----------------------------------------------------------------------===//
// SubTensorInsertOp
//===----------------------------------------------------------------------===//
@@ -3597,6 +3646,13 @@ void mlir::SubTensorInsertOp::build(OpBuilder &b, OperationState &result,
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
+/// Fold a subtensor_insert whose source type equals its result type and whose
+/// offsets/sizes/strides form an identity view to the inserted source tensor.
+OpFoldResult SubTensorInsertOp::fold(ArrayRef<Attribute>) {
+ if (getSourceType() == getType() &&
+ succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
+ return this->source();
+ return OpFoldResult();
+}
+
//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index c864af8f5747..9247152e8677 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -157,3 +157,22 @@ func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
}
+
+// CHECK-LABEL: func @trivial_subtensor
+// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
+// CHECK-NOT: subtensor
+// CHECK: return %[[ARG0]] : tensor<4x6x16x32xi8>
+// An identity view (zero offsets, full sizes, unit strides) of the whole
+// source tensor must fold away to the source operand.
+func @trivial_subtensor(%arg0 : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
+ %0 = subtensor %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<4x6x16x32xi8>
+ return %0 : tensor<4x6x16x32xi8>
+}
+
+// CHECK-LABEL: func @trivial_subtensor_insert
+// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
+// CHECK-NOT: subtensor
+// CHECK: return %[[ARG0]] : tensor<4x6x16x32xi8>
+// Inserting the whole source over the whole destination (zero offsets, full
+// sizes, unit strides) must fold away to the inserted source operand.
+func @trivial_subtensor_insert(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
+ %0 = subtensor_insert %arg0 into %arg1[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> into tensor<4x6x16x32xi8>
+ return %0 : tensor<4x6x16x32xi8>
+}
+
More information about the Mlir-commits
mailing list