[Mlir-commits] [mlir] 3648065 - [mlir][vector] Add canonicalization patterns for ExtractStride/ShapeCast + Splat constant

Thomas Raoux llvmlistbot at llvm.org
Tue Nov 3 11:30:05 PST 2020


Author: Thomas Raoux
Date: 2020-11-03T11:29:54-08:00
New Revision: 36480657d8ce97836f76bf5fa8c36677b9cdc19a

URL: https://github.com/llvm/llvm-project/commit/36480657d8ce97836f76bf5fa8c36677b9cdc19a
DIFF: https://github.com/llvm/llvm-project/commit/36480657d8ce97836f76bf5fa8c36677b9cdc19a.diff

LOG: [mlir][vector] Add canonicalization patterns for ExtractStridedSlice/ShapeCast of a splat constant

Differential Revision: https://reviews.llvm.org/D90567

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 604c2994a059..2ddd06ccf44f 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1649,6 +1649,7 @@ def Vector_ShapeCastOp :
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Vector_BitCastOp :

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 53cdf3fc9103..04b8b757a14b 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1770,13 +1770,39 @@ class StridedSliceConstantMaskFolder final
   }
 };
 
+// Rewrites ExtractStridedSliceOp(splat ConstantOp) -> splat ConstantOp.
+// Every element of a splat is identical, so any strided slice of it is the
+// same splat value materialized with the slice's result vector type.
+class StridedSliceConstantFolder final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    // Bail out unless the sliced vector comes straight from a ConstantOp.
+    auto sourceCst = sliceOp.vector().getDefiningOp<ConstantOp>();
+    if (!sourceCst)
+      return failure();
+    // Non-splat constants are left untouched by this pattern.
+    auto splat = sourceCst.value().dyn_cast<SplatElementsAttr>();
+    if (!splat)
+      return failure();
+    auto sliceType = sliceOp.getType().cast<VectorType>();
+    auto foldedAttr = DenseElementsAttr::get(sliceType, splat.getSplatValue());
+    rewriter.replaceOpWithNewOp<ConstantOp>(sliceOp, foldedAttr);
+    return success();
+  }
+};
+
 } // end anonymous namespace
 
 void ExtractStridedSliceOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
-  // ConstantMaskOp.
-  results.insert<StridedSliceConstantMaskFolder>(context);
+  // ConstantMaskOp and ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
+  results.insert<StridedSliceConstantMaskFolder, StridedSliceConstantFolder>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2560,6 +2586,36 @@ OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
+namespace {
+// Rewrites ShapeCastOp(splat ConstantOp) -> ConstantOp holding the same splat.
+class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
+public:
+  using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShapeCastOp castOp,
+                                PatternRewriter &rewriter) const override {
+    // Bail out unless the cast operand comes straight from a ConstantOp.
+    auto sourceCst = castOp.source().getDefiningOp<ConstantOp>();
+    if (!sourceCst)
+      return failure();
+    // Only splat payloads are handled for now.
+    auto splat = sourceCst.value().dyn_cast<SplatElementsAttr>();
+    if (!splat)
+      return failure();
+    auto castType = castOp.getType().cast<VectorType>();
+    auto foldedAttr = DenseElementsAttr::get(castType, splat.getSplatValue());
+    rewriter.replaceOpWithNewOp<ConstantOp>(castOp, foldedAttr);
+    return success();
+  }
+};
+} // end anonymous namespace
+
+void ShapeCastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                              MLIRContext *context) {
+  // Pattern to rewrite a ShapeCastOp(splat ConstantOp) -> ConstantOp.
+  results.insert<ShapeCastConstantFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // VectorBitCastOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 00905420c118..f07285d7d98c 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -580,3 +580,37 @@ func @broadcast_folding2() -> vector<4x16xi32> {
   %2 = vector.broadcast %1 : vector<16xi32> to vector<4x16xi32>
   return %2 : vector<4x16xi32>
 }
+
+// -----
+
+// CHECK-LABEL: shape_cast_constant
+//       CHECK: %[[F32_SPLAT:.*]] = constant dense<2.000000e+00> : vector<20x2xf32>
+//       CHECK: %[[I32_SPLAT:.*]] = constant dense<1> : vector<3x4x2xi32>
+//       CHECK: return %[[F32_SPLAT]], %[[I32_SPLAT]] : vector<20x2xf32>, vector<3x4x2xi32>
+func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
+  %f32_splat = constant dense<2.000000e+00> : vector<5x4x2xf32>
+  %i32_splat = constant dense<1> : vector<12x2xi32>
+  %f32_cast = vector.shape_cast %f32_splat : vector<5x4x2xf32> to vector<20x2xf32>
+  %i32_cast = vector.shape_cast %i32_splat : vector<12x2xi32> to vector<3x4x2xi32>
+  return %f32_cast, %i32_cast : vector<20x2xf32>, vector<3x4x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_constant
+//       CHECK: %[[F32_SPLAT:.*]] = constant dense<2.000000e+00> : vector<12x2xf32>
+//       CHECK: %[[I32_SPLAT:.*]] = constant dense<1> : vector<2x13x3xi32>
+//       CHECK: return %[[F32_SPLAT]], %[[I32_SPLAT]] : vector<12x2xf32>, vector<2x13x3xi32>
+func @extract_strided_constant() -> (vector<12x2xf32>, vector<2x13x3xi32>) {
+  %f32_splat = constant dense<2.000000e+00> : vector<29x7xf32>
+  %i32_splat = constant dense<1> : vector<4x37x9xi32>
+  %f32_slice = vector.extract_strided_slice %f32_splat
+    {offsets = [2, 3], sizes = [12, 2], strides = [1, 1]}
+      : vector<29x7xf32> to vector<12x2xf32>
+  %i32_slice = vector.extract_strided_slice %i32_splat
+    {offsets = [1, 2, 5], sizes = [2, 13, 3], strides = [1, 1, 1]}
+      : vector<4x37x9xi32> to vector<2x13x3xi32>
+  return %f32_slice, %i32_slice : vector<12x2xf32>, vector<2x13x3xi32>
+}
+
+


        


More information about the Mlir-commits mailing list