[Mlir-commits] [mlir] [mlir][spirv] Add folding for IAddCarry/[S|U]MulExtended (PR #73340)

Jakub Kuderski llvmlistbot at llvm.org
Fri Nov 24 13:41:26 PST 2023


================
@@ -115,6 +115,196 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
   results.add<CombineChainedAccessChain>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// spirv.Constant cannot yet describe a struct value, so the folded result has
+// to be materialized with spirv.CompositeConstruct. A fold hook cannot create
+// new operations, hence this is a rewrite pattern instead of a fold.
+//
+// Rewrites performed:
+//   iaddcarry(x, 0)   -> <x, 0>
+//   iaddcarry(c1, c2) -> <c1 + c2, carry-out of c1 + c2>   (constants only)
+struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto operands = op.getOperands();
+
+    SmallVector<Value> constituents;
+    Type constituentType = operands[0].getType();
+
+    // According to the SPIR-V spec:
+    //
+    //  Result Type must be from OpTypeStruct.  The struct must have two
+    //  members...
+    //
+    //  Member 0 of the result gets the low-order bits (full component width) of
+    //  the addition.
+    //
+    //  Member 1 of the result gets the high-order (carry) bit of the result of
+    //  the addition. That is, it gets the value 1 if the addition overflowed
+    //  the component width, and 0 otherwise.
+
+    // iaddcarry (x, 0) = <x, 0>: the sum (member 0) is x, and adding zero can
+    // never carry, so member 1 is 0. The operands share the struct member
+    // type, so the zero operand itself is reused as the carry value.
+    if (matchPattern(operands[1], m_Zero())) {
+      constituents.push_back(operands[0]);
+      constituents.push_back(operands[1]);
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               constituents);
+      return success();
+    }
+
+    // Otherwise both operands must be constants for the result to be folded.
+    Attribute lhs;
+    Attribute rhs;
+    if (!matchPattern(operands[0], m_Constant(&lhs)) ||
+        !matchPattern(operands[1], m_Constant(&rhs)))
+      return failure();
+
+    // APInt addition wraps modulo 2^width, which is exactly member 0.
+    auto adds = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) { return a + b; });
+    if (!adds)
+      return failure();
+
+    // With wrapping unsigned addition, the sum overflowed iff it is
+    // (unsigned) smaller than an addend; compare it against lhs.
+    auto carrys = constFoldBinaryOp<IntegerAttr>(
+        ArrayRef{adds, lhs}, [](const APInt &a, const APInt &b) {
+          APInt zero = APInt::getZero(a.getBitWidth());
+          return a.ult(b) ? (zero + 1) : zero;
+        });
+
+    if (!carrys)
+      return failure();
+
+    Value addsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
+    constituents.push_back(addsVal);
+
+    Value carrysVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
+    constituents.push_back(carrysVal);
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                             constituents);
+    return success();
+  }
+};
+
+void spirv::IAddCarryOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<IAddCarryFold>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.[S|U]MulExtended
+//===----------------------------------------------------------------------===//
+
+// spirv.Constant cannot yet describe a struct value, so the result of these
+// rewrites is assembled with spirv.CompositeConstruct. Creating that op is not
+// allowed from a fold hook, which is why this is a rewrite pattern.
+//
+// MulOp is the extended-multiply op being matched. IsSigned chooses between
+// sign-extension and zero-extension of the operands when computing the high
+// half of the product.
+//
+// Rewrites performed:
+//   [su]mulextended(x, 0)   -> <0, 0>
+//   [su]mulextended(c1, c2) -> <low(c1 * c2), high(c1 * c2)>  (constants only)
+template <typename MulOp, bool IsSigned>
+struct MulExtendedFold final : OpRewritePattern<MulOp> {
+  using OpRewritePattern<MulOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MulOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value lhsVal = op.getOperands()[0];
+    Value rhsVal = op.getOperands()[1];
+    Type memberType = lhsVal.getType();
+
+    // Anything times zero is zero, so both halves of the product are 0.
+    if (matchPattern(rhsVal, m_Zero())) {
+      Value zero = spirv::ConstantOp::getZero(memberType, loc, rewriter);
+      SmallVector<Value> halves(2, zero);
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               halves);
+      return success();
+    }
+
+    // Per the SPIR-V spec, the result struct has two members: member 0 holds
+    // the low-order bits of the multiplication and member 1 holds the
+    // high-order bits. Both are only computable when the operands are constant.
+    Attribute lhsAttr;
+    Attribute rhsAttr;
+    bool bothConstant = matchPattern(lhsVal, m_Constant(&lhsAttr)) &&
+                        matchPattern(rhsVal, m_Constant(&rhsAttr));
+    if (!bothConstant)
+      return failure();
+
+    // The low half is the wrapping product, which is independent of sign.
+    Attribute low = constFoldBinaryOp<IntegerAttr>(
+        {lhsAttr, rhsAttr},
+        [](const APInt &a, const APInt &b) { return a * b; });
+    if (!low)
+      return failure();
+
+    // The high half needs the full double-width product: extend both operands
+    // to 2 * width, multiply, then keep the upper `width` bits.
+    Attribute high = constFoldBinaryOp<IntegerAttr>(
+        {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
+          unsigned width = a.getBitWidth();
+          APInt product = IsSigned ? a.sext(2 * width) * b.sext(2 * width)
+                                   : a.zext(2 * width) * b.zext(2 * width);
+          return product.extractBits(width, width);
+        });
+    if (!high)
+      return failure();
+
+    Value lowConst =
+        rewriter.create<spirv::ConstantOp>(loc, memberType, low);
+    Value highConst =
+        rewriter.create<spirv::ConstantOp>(loc, memberType, high);
+    SmallVector<Value> halves = {lowConst, highConst};
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                             halves);
+    return success();
+  }
+};
+
+// Signed variant: the high half is computed from sign-extended operands.
+using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>;
+void spirv::SMulExtendedOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<SMulExtendedOpFold>(context);
+}
+
+struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto operands = op.getOperands();
+
+    SmallVector<Value> constituents;
+    Type constituentType = operands[0].getType();
+
+    // umulextended (x, 1) = <x, 0>
+    if (matchPattern(operands[1], m_One())) {
+      Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+      constituents.push_back(operands[0]);
+      constituents.push_back(zero);
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
----------------
kuhar wrote:

Also here

https://github.com/llvm/llvm-project/pull/73340


More information about the Mlir-commits mailing list