[Mlir-commits] [mlir] 3648065 - [mlir][vector] Add canonicalization patterns for ExtractStride/ShapeCast + Splat constant
Thomas Raoux
llvmlistbot at llvm.org
Tue Nov 3 11:30:05 PST 2020
Author: Thomas Raoux
Date: 2020-11-03T11:29:54-08:00
New Revision: 36480657d8ce97836f76bf5fa8c36677b9cdc19a
URL: https://github.com/llvm/llvm-project/commit/36480657d8ce97836f76bf5fa8c36677b9cdc19a
DIFF: https://github.com/llvm/llvm-project/commit/36480657d8ce97836f76bf5fa8c36677b9cdc19a.diff
LOG: [mlir][vector] Add canonicalization patterns for ExtractStride/ShapeCast + Splat constant
Differential Revision: https://reviews.llvm.org/D90567
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 604c2994a059..2ddd06ccf44f 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1649,6 +1649,7 @@ def Vector_ShapeCastOp :
}];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
def Vector_BitCastOp :
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 53cdf3fc9103..04b8b757a14b 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1770,13 +1770,39 @@ class StridedSliceConstantMaskFolder final
}
};
+// Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
+class StridedSliceConstantFolder final
+ : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+ using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
+ PatternRewriter &rewriter) const override {
+ // Bail out unless the source vector of 'extractStridedSliceOp' is
+ // produced by a ConstantOp.
+ auto constantOp =
+ extractStridedSliceOp.vector().getDefiningOp<ConstantOp>();
+ if (!constantOp)
+ return failure();
+ auto dense = constantOp.value().dyn_cast<SplatElementsAttr>(); // Splat only.
+ if (!dense)
+ return failure();
+ auto newAttr = DenseElementsAttr::get(
+ extractStridedSliceOp.getType().cast<VectorType>(),
+ dense.getSplatValue()); // Offsets/sizes/strides don't matter for a splat.
+ rewriter.replaceOpWithNewOp<ConstantOp>(extractStridedSliceOp, newAttr);
+ return success();
+ }
+};
+
} // end anonymous namespace
void ExtractStridedSliceOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
- // ConstantMaskOp.
- results.insert<StridedSliceConstantMaskFolder>(context);
+ // ConstantMaskOp and ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
+ results.insert<StridedSliceConstantMaskFolder, StridedSliceConstantFolder>(
+ context);
}
//===----------------------------------------------------------------------===//
@@ -2560,6 +2586,36 @@ OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
return {};
}
+namespace {
+// Pattern to rewrite a ShapeCastOp(splat ConstantOp) -> ConstantOp.
+class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
+public:
+ using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
+ auto constantOp = shapeCastOp.source().getDefiningOp<ConstantOp>();
+ if (!constantOp)
+ return failure();
+ // Only handle splat for now; a splat keeps its value under any shape.
+ auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
+ if (!dense)
+ return failure();
+ auto newAttr = DenseElementsAttr::get(
+ shapeCastOp.getType().cast<VectorType>(), dense.getSplatValue()); // Same value, new shape.
+ rewriter.replaceOpWithNewOp<ConstantOp>(shapeCastOp, newAttr);
+ return success();
+ }
+};
+
+} // end anonymous namespace
+
+void ShapeCastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ // Pattern to rewrite a ShapeCastOp(splat ConstantOp) -> ConstantOp.
+ results.insert<ShapeCastConstantFolder>(context);
+}
+
//===----------------------------------------------------------------------===//
// VectorBitCastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 00905420c118..f07285d7d98c 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -580,3 +580,37 @@ func @broadcast_folding2() -> vector<4x16xi32> {
%2 = vector.broadcast %1 : vector<16xi32> to vector<4x16xi32>
return %2 : vector<4x16xi32>
}
+
+// -----
+
+// CHECK-LABEL: shape_cast_constant
+// CHECK: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<20x2xf32>
+// CHECK: %[[CST1:.*]] = constant dense<1> : vector<3x4x2xi32>
+// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
+func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) { // Splat constants fold through vector.shape_cast.
+ %cst = constant dense<2.000000e+00> : vector<5x4x2xf32>
+ %cst_1 = constant dense<1> : vector<12x2xi32>
+ %0 = vector.shape_cast %cst : vector<5x4x2xf32> to vector<20x2xf32>
+ %1 = vector.shape_cast %cst_1 : vector<12x2xi32> to vector<3x4x2xi32>
+ return %0, %1 : vector<20x2xf32>, vector<3x4x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_constant
+// CHECK: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<12x2xf32>
+// CHECK: %[[CST1:.*]] = constant dense<1> : vector<2x13x3xi32>
+// CHECK: return %[[CST0]], %[[CST1]] : vector<12x2xf32>, vector<2x13x3xi32>
+func @extract_strided_constant() -> (vector<12x2xf32>, vector<2x13x3xi32>) { // Splat constants fold through vector.extract_strided_slice regardless of offsets.
+ %cst = constant dense<2.000000e+00> : vector<29x7xf32>
+ %cst_1 = constant dense<1> : vector<4x37x9xi32>
+ %0 = vector.extract_strided_slice %cst
+ {offsets = [2, 3], sizes = [12, 2], strides = [1, 1]}
+ : vector<29x7xf32> to vector<12x2xf32>
+ %1 = vector.extract_strided_slice %cst_1
+ {offsets = [1, 2, 5], sizes = [2, 13, 3], strides = [1, 1, 1]}
+ : vector<4x37x9xi32> to vector<2x13x3xi32>
+ return %0, %1 : vector<12x2xf32>, vector<2x13x3xi32>
+}
+
+
More information about the Mlir-commits mailing list