[Mlir-commits] [mlir] [mlir][spirv] Add folding for IAddCarry/[S|U]MulExtended (PR #73340)
Jakub Kuderski
llvmlistbot at llvm.org
Fri Nov 24 13:41:25 PST 2023
================
@@ -115,6 +115,196 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
results.add<CombineChainedAccessChain>(context);
}
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// Folds spirv.IAddCarry when the right-hand side is zero or when both operands
+// are constants.
+//
+// The result has to be materialized with a CompositeConstructOp because
+// constant structs are not yet implemented as constants; hence this cannot be
+// done in a fold.
+struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto operands = op.getOperands();
+
+    SmallVector<Value> constituents;
+    Type constituentType = operands[0].getType();
+
+    // According to the SPIR-V spec:
+    //
+    //  Result Type must be from OpTypeStruct.  The struct must have two
+    //  members...
+    //
+    //  Member 0 of the result gets the low-order bits (full component width) of
+    //  the addition.
+    //
+    //  Member 1 of the result gets the high-order (carry) bit of the result of
+    //  the addition. That is, it gets the value 1 if the addition overflowed
+    //  the component width, and 0 otherwise.
+
+    // iaddcarry (x, 0) = <x, 0>
+    //
+    // The sum is `x` itself (member 0) and adding zero never overflows, so the
+    // carry (member 1) is zero. The matched zero operand is reused as the
+    // carry value.
+    if (matchPattern(operands[1], m_Zero())) {
+      constituents.push_back(operands[0]);
+      constituents.push_back(operands[1]);
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               constituents);
+      return success();
+    }
+
+    // Otherwise both operands must be constant for the fold to apply.
+    Attribute lhs;
+    Attribute rhs;
+    if (!matchPattern(operands[0], m_Constant(&lhs)) ||
+        !matchPattern(operands[1], m_Constant(&rhs)))
+      return failure();
+
+    // Member 0: the wrapped sum at the component width.
+    auto adds = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) { return a + b; });
+    if (!adds)
+      return failure();
+
+    // Member 1: the carry. Unsigned addition overflowed exactly when the
+    // wrapped sum is (unsigned) smaller than one of the inputs.
+    auto carrys = constFoldBinaryOp<IntegerAttr>(
+        ArrayRef{adds, lhs}, [](const APInt &a, const APInt &b) {
+          APInt zero = APInt::getZero(a.getBitWidth());
+          return a.ult(b) ? (zero + 1) : zero;
+        });
+
+    if (!carrys)
+      return failure();
+
+    Value addsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
+    constituents.push_back(addsVal);
+
+    Value carrysVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
+    constituents.push_back(carrysVal);
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                             constituents);
+    return success();
+  }
+};
+
+void spirv::IAddCarryOp::getCanonicalizationPatterns(
+ RewritePatternSet &patterns, MLIRContext *context) {
+ patterns.add<IAddCarryFold>(context); // Adds the IAddCarryFold rewrite pattern to the canonicalization pattern set.
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.[S|U]MulExtended
+//===----------------------------------------------------------------------===//
+
+// We are required to use CompositeConstructOp to create a constant struct, as
+// constant structs are not yet implemented; hence we cannot do this in a fold.
+template <typename MulOp, bool IsSigned>
+struct MulExtendedFold final : OpRewritePattern<MulOp> {
+ using OpRewritePattern<MulOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(MulOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto operands = op.getOperands();
+
+ SmallVector<Value> constituents;
+ Type constituentType = operands[0].getType();
+
+ // [su]mulextended (x, 0) = <0, 0>
+ if (matchPattern(operands[1], m_Zero())) {
+ Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+ constituents.push_back(zero);
+ constituents.push_back(zero);
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+ constituents);
+ return success();
+ }
+
+ // According to the SPIR-V spec:
+ //
+ // Result Type must be from OpTypeStruct. The struct must have two
+ // members...
+ //
+ // Member 0 of the result gets the low-order bits of the multiplication.
+ //
+ // Member 1 of the result gets the high-order bits of the multiplication.
+ Attribute lhs;
+ Attribute rhs;
+ if (!matchPattern(operands[0], m_Constant(&lhs)) ||
+ !matchPattern(operands[1], m_Constant(&rhs)))
+ return failure();
+
+ auto lowBits = constFoldBinaryOp<IntegerAttr>(
+ {lhs, rhs}, [](const APInt &a, const APInt &b) { return a * b; });
+
+ if (!lowBits)
+ return failure();
+
+ auto highBits = constFoldBinaryOp<IntegerAttr>(
+ {lhs, rhs}, [](const APInt &a, const APInt &b) {
+ unsigned bitWidth = a.getBitWidth();
+ APInt c;
+ if (IsSigned) {
+ c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
+ } else {
+ c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
+ }
+ return c.extractBits(bitWidth, bitWidth); // Extract high result
+ });
+
+ if (!highBits)
+ return failure();
+
+ Value lowBitsVal =
+ rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
+ constituents.push_back(lowBitsVal);
+
+ Value highBitsVal =
+ rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
+ constituents.push_back(highBitsVal);
+
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+ constituents);
----------------
kuhar wrote:
Same here
https://github.com/llvm/llvm-project/pull/73340
More information about the Mlir-commits
mailing list