[Mlir-commits] [mlir] f79f430 - Fold Tensor.extract_slice into a constant splat.

Okwan Kwon llvmlistbot at llvm.org
Tue Feb 22 13:40:09 PST 2022


Author: Okwan Kwon
Date: 2022-02-22T21:39:57Z
New Revision: f79f430d4b268429f96be95622facd2775b25624

URL: https://github.com/llvm/llvm-project/commit/f79f430d4b268429f96be95622facd2775b25624
DIFF: https://github.com/llvm/llvm-project/commit/f79f430d4b268429f96be95622facd2775b25624.diff

LOG: Fold Tensor.extract_slice into a constant splat.

Fold tensor.extract_slice into arith.constant when the source is a constant
splat and the result type is statically shaped.

Added: 
    

Modified: 
    mlir/include/mlir/IR/BuiltinAttributes.h
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 5399718f67582..4371a1cb088f9 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -655,6 +655,11 @@ class DenseElementsAttr : public Attribute {
   /// same total number of elements as well as element type.
   DenseElementsAttr reshape(ShapedType newType);
 
+  /// Return a new DenseElementsAttr that has the same data as the current
+  /// attribute, but with a different shape for a splat type. The new type must
+  /// have the same element type.
+  DenseElementsAttr resizeSplat(ShapedType newType);
+
   /// Return a new DenseElementsAttr that has the same data as the current
   /// attribute, but has bitcast elements to 'newElType'. The new type must have
   /// the same bitwidth as the current element type.

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 5edb620d5cc32..70aa7b5fe57f6 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1227,7 +1227,12 @@ static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
   return {};
 }
 
-OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
+OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
+  if (auto splat = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
+    auto resultType = result().getType().cast<ShapedType>();
+    if (resultType.hasStaticShape())
+      return splat.resizeSplat(resultType);
+  }
   if (getSourceType() == getType() &&
       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
     return this->source();

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 79e80f7c1317c..6988d1f8e4c30 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -967,6 +967,18 @@ DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
 }
 
+DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) {
+  assert(isSplat() && "expected a splat type");
+
+  ShapedType curType = getType();
+  if (curType == newType)
+    return *this;
+
+  assert(newType.getElementType() == curType.getElementType() &&
+         "expected the same element type");
+  return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), true);
+}
+
 /// Return a new DenseElementsAttr that has the same data as the current
 /// attribute, but has bitcast elements such that it is now 'newType'. The new
 /// type must have the same shape and element types of the same bitwidth as the

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ce3db8d6039c2..22770c2e67342 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -621,6 +621,17 @@ func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8xf32>,
 
 // -----
 
+// CHECK-LABEL: func @fold_extract_constant_splat
+//   CHECK-NOT: tensor.extract_slice
+//       CHECK: arith.constant dense<42> : tensor<4x4xi32>
+func @fold_extract_constant_splat() -> (tensor<4x4xi32>) {
+  %cst = arith.constant dense<42> : tensor<1024x1024xi32>
+  %1 = tensor.extract_slice %cst[0,0] [4,4] [1, 1] : tensor<1024x1024xi32> to tensor<4x4xi32>
+  return %1 : tensor<4x4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_overlapping_insert
 //  CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
 func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {


        


More information about the Mlir-commits mailing list