[Mlir-commits] [mlir] 14028ec - [mlir][spirv] Add canon patterns for IAddCarry/[S|U]MulExtended (#73340)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 29 11:32:19 PST 2023
Author: Finn Plummer
Date: 2023-11-29T14:32:13-05:00
New Revision: 14028ec0a62210d68a4dd7a046ac79c8c3b7727e
URL: https://github.com/llvm/llvm-project/commit/14028ec0a62210d68a4dd7a046ac79c8c3b7727e
DIFF: https://github.com/llvm/llvm-project/commit/14028ec0a62210d68a4dd7a046ac79c8c3b7727e.diff
LOG: [mlir][spirv] Add canon patterns for IAddCarry/[S|U]MulExtended (#73340)
Add missing constant propagation folder for IAddCarry and
[S|U]MulExtended. Due to currently missing constant value for
spirv.struct the folding is done using canonicalization patterns.
Implement additional folding when rhs is 0 for all ops and when rhs is 1
for UMulExt.
This helps for readability of lowered code into SPIR-V.
Part of work for #70704
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 701389d1cf4c1ec..51124e141c6d469 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -316,6 +316,8 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
%2 = spirv.IAddCarry %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];
+
+ let hasCanonicalizer = 1;
}
// -----
@@ -551,6 +553,8 @@ def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
%2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];
+
+ let hasCanonicalizer = 1;
}
// -----
@@ -675,6 +679,8 @@ def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended",
%2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];
+
+ let hasCanonicalizer = 1;
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 82af41643edb89d..22cb9bf718e36f4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -122,6 +122,200 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
results.add<CombineChainedAccessChain>(context);
}
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// Rewrite pattern (not a fold: struct constants are not yet implemented) that
+// simplifies spirv.IAddCarry with a zero rhs or two constant operands.
+struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
+ using OpRewritePattern::OpRewritePattern;
+ // Result struct layout: member 0 is the sum, member 1 is the carry.
+ LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Value lhs = op.getOperand1();
+ Value rhs = op.getOperand2();
+ Type constituentType = lhs.getType();
+
+ // iaddcarry (x, 0) = <x, 0>: the sum is x and no carry can occur.
+ if (matchPattern(rhs, m_Zero())) {
+ Value constituents[2] = {lhs, rhs}; // rhs doubles as the zero carry
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+ constituents);
+ return success();
+ }
+
+ // According to the SPIR-V spec:
+ //
+ // Result Type must be from OpTypeStruct. The struct must have two
+ // members...
+ //
+ // Member 0 of the result gets the low-order bits (full component width) of
+ // the addition.
+ //
+ // Member 1 of the result gets the high-order (carry) bit of the result of
+ // the addition. That is, it gets the value 1 if the addition overflowed
+ // the component width, and 0 otherwise.
+ Attribute lhsAttr;
+ Attribute rhsAttr;
+ if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
+ !matchPattern(rhs, m_Constant(&rhsAttr)))
+ return failure(); // only fold when both operands are constants
+
+ auto adds = constFoldBinaryOp<IntegerAttr>(
+ {lhsAttr, rhsAttr},
+ [](const APInt &a, const APInt &b) { return a + b; });
+ if (!adds)
+ return failure();
+
+ auto carrys = constFoldBinaryOp<IntegerAttr>(
+ ArrayRef{adds, lhsAttr}, [](const APInt &a, const APInt &b) {
+ APInt zero = APInt::getZero(a.getBitWidth());
+ return a.ult(b) ? (zero + 1) : zero; // unsigned sum < lhs iff wrapped
+ });
+
+ if (!carrys)
+ return failure();
+
+ Value addsVal =
+ rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
+
+ Value carrysVal =
+ rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
+
+ // Start from an undef struct and insert both members into it.
+ Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
+ // Fill in adds at id 0
+ Value intermediate =
+ rewriter.create<spirv::CompositeInsertOp>(loc, addsVal, undef, 0);
+ // Fill in carrys at id 1
+ rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, carrysVal,
+ intermediate, 1);
+ return success();
+ }
+};
+
+void spirv::IAddCarryOp::getCanonicalizationPatterns(
+ RewritePatternSet &patterns, MLIRContext *context) {
+ patterns.add<IAddCarryFold>(context); // the only pattern registered here
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.[S|U]MulExtended
+//===----------------------------------------------------------------------===//
+
+// Rewrite pattern (not a fold: struct constants are not yet implemented) that
+// simplifies [S|U]MulExtended with a zero rhs or two constant operands.
+template <typename MulOp, bool IsSigned>
+struct MulExtendedFold final : OpRewritePattern<MulOp> {
+ using OpRewritePattern<MulOp>::OpRewritePattern;
+ // IsSigned picks sign- vs. zero-extension when computing the high-order bits.
+ LogicalResult matchAndRewrite(MulOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Value lhs = op.getOperand1();
+ Value rhs = op.getOperand2();
+ Type constituentType = lhs.getType();
+
+ // [su]mulextended (x, 0) = <0, 0>
+ if (matchPattern(rhs, m_Zero())) {
+ Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+ Value constituents[2] = {zero, zero};
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+ constituents);
+ return success();
+ }
+
+ // According to the SPIR-V spec:
+ //
+ // Result Type must be from OpTypeStruct. The struct must have two
+ // members...
+ //
+ // Member 0 of the result gets the low-order bits of the multiplication.
+ //
+ // Member 1 of the result gets the high-order bits of the multiplication.
+ Attribute lhsAttr;
+ Attribute rhsAttr;
+ if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
+ !matchPattern(rhs, m_Constant(&rhsAttr)))
+ return failure(); // only fold when both operands are constants
+
+ auto lowBits = constFoldBinaryOp<IntegerAttr>(
+ {lhsAttr, rhsAttr},
+ [](const APInt &a, const APInt &b) { return a * b; });
+
+ if (!lowBits)
+ return failure();
+
+ auto highBits = constFoldBinaryOp<IntegerAttr>(
+ {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
+ unsigned bitWidth = a.getBitWidth();
+ APInt c; // full product, computed at twice the operand width
+ if (IsSigned) {
+ c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
+ } else {
+ c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
+ }
+ return c.extractBits(bitWidth, bitWidth); // upper half = high-order bits
+ });
+
+ if (!highBits)
+ return failure();
+
+ Value lowBitsVal =
+ rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
+
+ Value highBitsVal =
+ rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
+
+ // Start from an undef struct and insert both members into it.
+ Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
+ // Fill in lowBits at id 0
+ Value intermediate =
+ rewriter.create<spirv::CompositeInsertOp>(loc, lowBitsVal, undef, 0);
+ // Fill in highBits at id 1
+ rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, highBitsVal,
+ intermediate, 1);
+ return success();
+ }
+};
+
+using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>; // IsSigned = true
+void spirv::SMulExtendedOp::getCanonicalizationPatterns(
+ RewritePatternSet &patterns, MLIRContext *context) {
+ patterns.add<SMulExtendedOpFold>(context); // the only pattern registered here
+}
+
+struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
+ using OpRewritePattern::OpRewritePattern;
+ // Rewrites umulextended(x, 1) into <x, 0>; any other rhs does not match.
+ LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Value lhs = op.getOperand1();
+ Value rhs = op.getOperand2();
+ Type constituentType = lhs.getType();
+
+ // umulextended (x, 1) = <x, 0>
+ if (matchPattern(rhs, m_One())) {
+ Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+ Value constituents[2] = {lhs, zero}; // low-order bits = x, high-order = 0
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+ constituents);
+ return success();
+ }
+
+ return failure();
+ }
+};
+
+using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false>; // IsSigned = false
+void spirv::UMulExtendedOp::getCanonicalizationPatterns(
+ RewritePatternSet &patterns, MLIRContext *context) {
+ patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context); // fold + x*1 rewrite
+}
+
//===----------------------------------------------------------------------===//
// spirv.UMod
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 6fb5ca5c41839a0..867ddf3c8017336 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -336,6 +336,61 @@ func.func @iadd_poison(%arg0: i32) -> i32 {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @iaddcarry_x_0
+func.func @iaddcarry_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+ // CHECK: %[[RET:.*]] = spirv.CompositeConstruct
+ %c0 = spirv.Constant 0 : i32
+ %0 = spirv.IAddCarry %arg0, %c0 : !spirv.struct<(i32, i32)>
+ // NOTE(review): only the op kind is matched above, not its operand order.
+ // CHECK: return %[[RET]]
+ return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_iaddcarry
+func.func @const_fold_scalar_iaddcarry() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+ %c5 = spirv.Constant 5 : i32
+ %cn5 = spirv.Constant -5 : i32
+ %cn8 = spirv.Constant -8 : i32
+ // 5 + -8 = -3 without carry; -5 + -8 = -13 with an unsigned carry of 1.
+ // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
+ // CHECK-DAG: %[[CN3:.*]] = spirv.Constant -3
+ // CHECK-DAG: %[[UNDEF1:.*]] = spirv.Undef
+ // CHECK-DAG: %[[INTER1:.*]] = spirv.CompositeInsert %[[CN3]], %[[UNDEF1]][0 : i32]
+ // CHECK-DAG: %[[CC_CN3_C0:.*]] = spirv.CompositeInsert %[[C0]], %[[INTER1]][1 : i32]
+ // CHECK-DAG: %[[C1:.*]] = spirv.Constant 1
+ // CHECK-DAG: %[[CN13:.*]] = spirv.Constant -13
+ // CHECK-DAG: %[[UNDEF2:.*]] = spirv.Undef
+ // CHECK-DAG: %[[INTER2:.*]] = spirv.CompositeInsert %[[CN13]], %[[UNDEF2]][0 : i32]
+ // CHECK-DAG: %[[CC_CN13_C1:.*]] = spirv.CompositeInsert %[[C1]], %[[INTER2]][1 : i32]
+ %0 = spirv.IAddCarry %c5, %cn8 : !spirv.struct<(i32, i32)>
+ %1 = spirv.IAddCarry %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+ // CHECK: return %[[CC_CN3_C0]], %[[CC_CN13_C1]]
+ return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_iaddcarry
+func.func @const_fold_vector_iaddcarry() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+ %v0 = spirv.Constant dense<[5, -3, -1]> : vector<3xi32>
+ %v1 = spirv.Constant dense<[-8, -8, 1]> : vector<3xi32>
+ // Element-wise: sums are [-3, -11, 0] and unsigned carries are [0, 1, 1].
+ // CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[-3, -11, 0]>
+ // CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[0, 1, 1]>
+ // CHECK-DAG: %[[UNDEF:.*]] = spirv.Undef
+ // CHECK-DAG: %[[INTER:.*]] = spirv.CompositeInsert %[[CV1]], %[[UNDEF]][0 : i32]
+ // CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeInsert %[[CV2]], %[[INTER]][1 : i32]
+ %0 = spirv.IAddCarry %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+ // CHECK: return %[[CC_CV1_CV2]]
+ return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.IMul
//===----------------------------------------------------------------------===//
@@ -400,6 +455,133 @@ func.func @const_fold_vector_imul() -> vector<3xi32> {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.SMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @smulextended_x_0
+func.func @smulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+ // CHECK: %[[C0:.*]] = spirv.Constant 0
+ // CHECK: %[[RET:.*]] = spirv.CompositeConstruct %[[C0]], %[[C0]]
+ %c0 = spirv.Constant 0 : i32
+ %0 = spirv.SMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+ // Multiplying by a zero constant yields zero for both struct members.
+ // CHECK: return %[[RET]]
+ return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_smulextended
+func.func @const_fold_scalar_smulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+ %c5 = spirv.Constant 5 : i32
+ %cn5 = spirv.Constant -5 : i32
+ %cn8 = spirv.Constant -8 : i32
+ // Signed: 5 * -8 = -40 (high word -1); -5 * -8 = 40 (high word 0).
+ // CHECK-DAG: %[[CN40:.*]] = spirv.Constant -40
+ // CHECK-DAG: %[[CN1:.*]] = spirv.Constant -1
+ // CHECK-DAG: %[[UNDEF1:.*]] = spirv.Undef
+ // CHECK-DAG: %[[INTER1:.*]] = spirv.CompositeInsert %[[CN40]], %[[UNDEF1]][0 : i32]
+ // CHECK-DAG: %[[CC_CN40_CN1:.*]] = spirv.CompositeInsert %[[CN1]], %[[INTER1]]
+ // CHECK-DAG: %[[C40:.*]] = spirv.Constant 40
+ // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
+ // CHECK-DAG: %[[UNDEF2:.*]] = spirv.Undef
+ // CHECK-DAG: %[[INTER2:.*]] = spirv.CompositeInsert %[[C40]], %[[UNDEF2]][0 : i32]
+ // CHECK-DAG: %[[CC_C40_C0:.*]] = spirv.CompositeInsert %[[C0]], %[[INTER2]][1 : i32]
+ %0 = spirv.SMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
+ %1 = spirv.SMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+ // CHECK: return %[[CC_CN40_CN1]], %[[CC_C40_C0]]
+ return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_smulextended
+func.func @const_fold_vector_smulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+ %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
+ %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
+ // Signed products: low words [2147483643, 40, -1], high words [2, 0, -1].
+ // CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[2147483643, 40, -1]>
+ // CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[2, 0, -1]>
+ // CHECK-DAG: %[[UNDEF:.*]] = spirv.Undef
+ // CHECK-DAG: %[[INTER:.*]] = spirv.CompositeInsert %[[CV1]], %[[UNDEF]][0 : i32]
+ // CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeInsert %[[CV2]], %[[INTER]][1 : i32]
+ %0 = spirv.SMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+ // CHECK: return %[[CC_CV1_CV2]]
+ return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @umulextended_x_0
+func.func @umulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+ // CHECK: %[[C0:.*]] = spirv.Constant 0
+ // CHECK: %[[RET:.*]] = spirv.CompositeConstruct %[[C0]], %[[C0]]
+ %c0 = spirv.Constant 0 : i32
+ %0 = spirv.UMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+ // Multiplying by a zero constant yields zero for both struct members.
+ // CHECK: return %[[RET]]
+ return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @umulextended_x_1
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @umulextended_x_1(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+ // CHECK: %[[C0:.*]] = spirv.Constant 0
+ // CHECK: %[[RET:.*]] = spirv.CompositeConstruct %[[ARG]], %[[C0]]
+ %c1 = spirv.Constant 1 : i32
+ %0 = spirv.UMulExtended %arg0, %c1 : !spirv.struct<(i32, i32)>
+ // Multiplying by one keeps x as the low word; the high word is zero.
+ // CHECK: return %[[RET]]
+ return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_umulextended
+func.func @const_fold_scalar_umulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+ %c5 = spirv.Constant 5 : i32
+ %cn5 = spirv.Constant -5 : i32
+ %cn8 = spirv.Constant -8 : i32
+ // Unsigned: 5 * 0xFFFFFFF8 = <-40, 4>; 0xFFFFFFFB * 0xFFFFFFF8 = <40, -13>.
+
+ // CHECK-DAG: %[[C40:.*]] = spirv.Constant 40
+ // CHECK-DAG: %[[CN13:.*]] = spirv.Constant -13
+ // CHECK-DAG: %[[CN40:.*]] = spirv.Constant -40
+ // CHECK-DAG: %[[C4:.*]] = spirv.Constant 4
+ // CHECK-DAG: %[[UNDEF1:.*]] = spirv.Undef
+ // CHECK-DAG: %[[INTER1:.*]] = spirv.CompositeInsert %[[CN40]], %[[UNDEF1]][0 : i32]
+ // CHECK-DAG: %[[CC_CN40_C4:.*]] = spirv.CompositeInsert %[[C4]], %[[INTER1]][1 : i32]
+ // CHECK-DAG: %[[UNDEF2:.*]] = spirv.Undef
+ // CHECK-DAG: %[[INTER2:.*]] = spirv.CompositeInsert %[[C40]], %[[UNDEF2]][0 : i32]
+ // CHECK-DAG: %[[CC_C40_CN13:.*]] = spirv.CompositeInsert %[[CN13]], %[[INTER2]][1 : i32]
+ %0 = spirv.UMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
+ %1 = spirv.UMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+ // CHECK: return %[[CC_CN40_C4]], %[[CC_C40_CN13]]
+ return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_umulextended
+func.func @const_fold_vector_umulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+ %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
+ %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
+ // Unsigned products: low words [2147483643, 40, -1], high words [2, -13, 0].
+ // CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[2147483643, 40, -1]>
+ // CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[2, -13, 0]>
+ // CHECK-DAG: %[[UNDEF:.*]] = spirv.Undef
+ // CHECK-DAG: %[[INTER:.*]] = spirv.CompositeInsert %[[CV1]], %[[UNDEF]]
+ // CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeInsert %[[CV2]], %[[INTER]]
+ %0 = spirv.UMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+ // CHECK: return %[[CC_CV1_CV2]]
+ return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+}
+
+// -----
+
+
//===----------------------------------------------------------------------===//
// spirv.ISub
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list