[Mlir-commits] [mlir] 670ae4b - [MLIR][Shape] Fold `shape.mul`
Frederik Gossen
llvmlistbot at llvm.org
Fri Jul 24 06:31:01 PDT 2020
Author: Frederik Gossen
Date: 2020-07-24T13:30:45Z
New Revision: 670ae4b6da874270aa0cd8ab32120c17b2eadb95
URL: https://github.com/llvm/llvm-project/commit/670ae4b6da874270aa0cd8ab32120c17b2eadb95
DIFF: https://github.com/llvm/llvm-project/commit/670ae4b6da874270aa0cd8ab32120c17b2eadb95.diff
LOG: [MLIR][Shape] Fold `shape.mul`
Implement constant folding for `shape.mul`.
Differential Revision: https://reviews.llvm.org/D84438
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 425cf917283b..797dc0bc0cb6 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -326,6 +326,7 @@ def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
}];
let verifier = [{ return ::verify(*this); }];
+ let hasFolder = 1;
}
def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 2f641300c491..d2b0dbdedb05 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -695,6 +695,18 @@ static LogicalResult verify(MulOp op) {
return success();
}
+// Constant-folds `shape.mul`. Folding happens only when both operand
+// attributes are `IntegerAttr`s; otherwise no fold is performed (null result).
+// The product is always returned as an index-typed `IntegerAttr`, whether the
+// op's result is `!shape.size` or `index`.
+OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
+  IntegerAttr lhsAttr = operands[0].dyn_cast_or_null<IntegerAttr>();
+  IntegerAttr rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>();
+  if (!lhsAttr || !rhsAttr)
+    return nullptr;
+  return IntegerAttr::get(IndexType::get(getContext()),
+                          lhsAttr.getValue() * rhsAttr.getValue());
+}
+
//===----------------------------------------------------------------------===//
// ShapeOfOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index b4dca5e3c2bf..577656a0b362 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -734,3 +734,43 @@ func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
%result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
return %result : i1
}
+
+// -----
+
+// Constant-fold `mul` when both operands are `!shape.size` constants.
+// CHECK-LABEL: @fold_mul_size
+func @fold_mul_size() -> !shape.size {
+  // CHECK: %[[PRODUCT:.*]] = shape.const_size 6
+  // CHECK: return %[[PRODUCT]] : !shape.size
+  %two = shape.const_size 2
+  %three = shape.const_size 3
+  %product = shape.mul %two, %three : !shape.size, !shape.size -> !shape.size
+  return %product : !shape.size
+}
+
+// -----
+
+// Constant-fold `mul` when both operands are `index` constants.
+// CHECK-LABEL: @fold_mul_index
+func @fold_mul_index() -> index {
+  // CHECK: %[[PRODUCT:.*]] = constant 6 : index
+  // CHECK: return %[[PRODUCT]] : index
+  %two = constant 2 : index
+  %three = constant 3 : index
+  %product = shape.mul %two, %three : index, index -> index
+  return %product : index
+}
+
+// -----
+
+// Constant-fold `mul` when a `!shape.size` constant is multiplied by an
+// `index` constant; the folded value is expected as a `!shape.size` constant.
+// CHECK-LABEL: @fold_mul_mixed
+func @fold_mul_mixed() -> !shape.size {
+  // CHECK: %[[PRODUCT:.*]] = shape.const_size 6
+  // CHECK: return %[[PRODUCT]] : !shape.size
+  %two = shape.const_size 2
+  %three = constant 3 : index
+  %product = shape.mul %two, %three : !shape.size, index -> !shape.size
+  return %product : !shape.size
+}
+
More information about the Mlir-commits
mailing list