[Mlir-commits] [mlir] f79f430 - Fold Tensor.extract_slice into a constant splat.
Okwan Kwon
llvmlistbot at llvm.org
Tue Feb 22 13:40:09 PST 2022
Author: Okwan Kwon
Date: 2022-02-22T21:39:57Z
New Revision: f79f430d4b268429f96be95622facd2775b25624
URL: https://github.com/llvm/llvm-project/commit/f79f430d4b268429f96be95622facd2775b25624
DIFF: https://github.com/llvm/llvm-project/commit/f79f430d4b268429f96be95622facd2775b25624.diff
LOG: Fold Tensor.extract_slice into a constant splat.
Fold tensor.extract_slice into arith.constant when the source is a constant
splat and the result type is statically shaped.
Added:
Modified:
mlir/include/mlir/IR/BuiltinAttributes.h
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 5399718f67582..4371a1cb088f9 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -655,6 +655,11 @@ class DenseElementsAttr : public Attribute {
/// same total number of elements as well as element type.
DenseElementsAttr reshape(ShapedType newType);
+ /// Return a new DenseElementsAttr that has the same data as the current
+ /// attribute, but with a different shape for a splat type. The new type must
+ /// have the same element type, and this attribute must be a splat.
+ DenseElementsAttr resizeSplat(ShapedType newType);
+
/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but has bitcast elements to 'newElType'. The new type must have
/// the same bitwidth as the current element type.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 5edb620d5cc32..70aa7b5fe57f6 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1227,7 +1227,12 @@ static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
return {};
}
-OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
+OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
+ if (auto splat = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
+ auto resultType = result().getType().cast<ShapedType>();
+ if (resultType.hasStaticShape())
+ return splat.resizeSplat(resultType);
+ }
if (getSourceType() == getType() &&
succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
return this->source();
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 79e80f7c1317c..6988d1f8e4c30 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -967,6 +967,18 @@ DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
}
+/// Build an attribute of type `newType` that holds this splat attribute's
+/// single element value. Returns `*this` unchanged when the types already
+/// match. Only the element type is checked; unlike a reshape, no element-count
+/// check is performed, since every element of a splat is the same value.
+/// NOTE(review): the result is always constructed as a
+/// DenseIntOrFPElementsAttr, so splats with non-int/float (e.g. string)
+/// elements are presumably not supported here -- confirm before relying on it.
+DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) {
+ assert(isSplat() && "expected a splat type");
+
+ ShapedType curType = getType();
+ if (curType == newType)
+ return *this;
+
+ assert(newType.getElementType() == curType.getElementType() &&
+ "expected the same element type");
+ // Reuse the stored splat payload as-is; the trailing `true` presumably flags
+ // the raw buffer as a single splat value so it applies to every element of
+ // `newType` regardless of its shape -- verify against getRaw's contract.
+ return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), true);
+}
+
/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but has bitcast elements such that it is now 'newType'. The new
/// type must have the same shape and element types of the same bitwidth as the
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ce3db8d6039c2..22770c2e67342 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -621,6 +621,17 @@ func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8xf32>,
// -----
+// Extracting a statically shaped slice from a splat constant should fold into
+// a new splat constant whose type is the slice's result type.
+// CHECK-LABEL: func @fold_extract_constant_splat
+// CHECK-NOT: tensor.extract_slice
+// CHECK: arith.constant dense<42> : tensor<4x4xi32>
+func @fold_extract_constant_splat() -> (tensor<4x4xi32>) {
+ %cst = arith.constant dense<42> : tensor<1024x1024xi32>
+ %1 = tensor.extract_slice %cst[0,0] [4,4] [1, 1] : tensor<1024x1024xi32> to tensor<4x4xi32>
+ return %1 : tensor<4x4xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_overlapping_insert
// CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
More information about the Mlir-commits
mailing list