[Mlir-commits] [mlir] 670ae4b - [MLIR][Shape] Fold `shape.mul`

Frederik Gossen llvmlistbot at llvm.org
Fri Jul 24 06:31:01 PDT 2020


Author: Frederik Gossen
Date: 2020-07-24T13:30:45Z
New Revision: 670ae4b6da874270aa0cd8ab32120c17b2eadb95

URL: https://github.com/llvm/llvm-project/commit/670ae4b6da874270aa0cd8ab32120c17b2eadb95
DIFF: https://github.com/llvm/llvm-project/commit/670ae4b6da874270aa0cd8ab32120c17b2eadb95.diff

LOG: [MLIR][Shape] Fold `shape.mul`

Implement constant folding for `shape.mul`.

Differential Revision: https://reviews.llvm.org/D84438

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 425cf917283b..797dc0bc0cb6 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -326,6 +326,7 @@ def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
   }];
 
   let verifier = [{ return ::verify(*this); }];
+  let hasFolder = 1;
 }
 
 def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 2f641300c491..d2b0dbdedb05 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -695,6 +695,18 @@ static LogicalResult verify(MulOp op) {
   return success();
 }
 
+// Constant-folds `shape.mul` once both operands are known integer attributes.
+OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
+  auto lhsAttr = operands[0].dyn_cast_or_null<IntegerAttr>();
+  auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>();
+  if (!lhsAttr || !rhsAttr)
+    return nullptr;
+  // NOTE(review): APInt `*` presumes equal bit widths; both operands are
+  // presumably index-typed. The product is always returned as `index`.
+  APInt product = lhsAttr.getValue() * rhsAttr.getValue();
+  return IntegerAttr::get(IndexType::get(getContext()), product);
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeOfOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index b4dca5e3c2bf..577656a0b362 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -734,3 +734,43 @@ func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
   %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
   return %result : i1
 }
+
+// -----
+
+// `mul` of two constant `!shape.size` operands folds to a single const_size.
+// CHECK-LABEL: @fold_mul_size
+func @fold_mul_size() -> !shape.size {
+  // CHECK: %[[RESULT:.*]] = shape.const_size 6
+  // CHECK: return %[[RESULT]] : !shape.size
+  %lhs = shape.const_size 2
+  %rhs = shape.const_size 3
+  %product = shape.mul %lhs, %rhs : !shape.size, !shape.size -> !shape.size
+  return %product : !shape.size
+}
+
+// -----
+
+// `mul` of two constant `index` operands folds to a standard `constant`.
+// CHECK-LABEL: @fold_mul_index
+func @fold_mul_index() -> index {
+  // CHECK: %[[RESULT:.*]] = constant 6 : index
+  // CHECK: return %[[RESULT]] : index
+  %lhs = constant 2 : index
+  %rhs = constant 3 : index
+  %product = shape.mul %lhs, %rhs : index, index -> index
+  return %product : index
+}
+
+// -----
+
+// Mixed `!shape.size` x `index` constants also fold to a single const_size.
+// CHECK-LABEL: @fold_mul_mixed
+func @fold_mul_mixed() -> !shape.size {
+  // CHECK: %[[RESULT:.*]] = shape.const_size 6
+  // CHECK: return %[[RESULT]] : !shape.size
+  %lhs = shape.const_size 2
+  %rhs = constant 3 : index
+  %product = shape.mul %lhs, %rhs : !shape.size, index -> !shape.size
+  return %product : !shape.size
+}
+


        


More information about the Mlir-commits mailing list