fix: add PaymentProvider field to prevent cross-gateway callback attacks
EPay allows users to switch payment methods (e.g. wxpay→alipay) during checkout, causing callback rejection. Replace fragile blocklist guard with a PaymentProvider field on TopUp and SubscriptionOrder that identifies which gateway created the order.
This commit is contained in:
@@ -36,30 +36,32 @@ func insertSubscriptionPlanForPaymentGuardTest(t *testing.T, id int) *Subscripti
|
||||
return plan
|
||||
}
|
||||
|
||||
func insertSubscriptionOrderForPaymentGuardTest(t *testing.T, tradeNo string, userID int, planID int, paymentMethod string) {
|
||||
func insertSubscriptionOrderForPaymentGuardTest(t *testing.T, tradeNo string, userID int, planID int, paymentProvider string) {
|
||||
t.Helper()
|
||||
order := &SubscriptionOrder{
|
||||
UserId: userID,
|
||||
PlanId: planID,
|
||||
Money: 9.99,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: paymentMethod,
|
||||
Status: common.TopUpStatusPending,
|
||||
CreateTime: time.Now().Unix(),
|
||||
UserId: userID,
|
||||
PlanId: planID,
|
||||
Money: 9.99,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: paymentProvider,
|
||||
PaymentProvider: paymentProvider,
|
||||
Status: common.TopUpStatusPending,
|
||||
CreateTime: time.Now().Unix(),
|
||||
}
|
||||
require.NoError(t, order.Insert())
|
||||
}
|
||||
|
||||
func insertTopUpForPaymentGuardTest(t *testing.T, tradeNo string, userID int, paymentMethod string) {
|
||||
func insertTopUpForPaymentGuardTest(t *testing.T, tradeNo string, userID int, paymentProvider string) {
|
||||
t.Helper()
|
||||
topUp := &TopUp{
|
||||
UserId: userID,
|
||||
Amount: 2,
|
||||
Money: 9.99,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: paymentMethod,
|
||||
Status: common.TopUpStatusPending,
|
||||
CreateTime: time.Now().Unix(),
|
||||
UserId: userID,
|
||||
Amount: 2,
|
||||
Money: 9.99,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: paymentProvider,
|
||||
PaymentProvider: paymentProvider,
|
||||
Status: common.TopUpStatusPending,
|
||||
CreateTime: time.Now().Unix(),
|
||||
}
|
||||
require.NoError(t, topUp.Insert())
|
||||
}
|
||||
@@ -89,7 +91,7 @@ func TestRechargeWaffoPancake_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
truncateTables(t)
|
||||
|
||||
insertUserForPaymentGuardTest(t, 101, 0)
|
||||
insertTopUpForPaymentGuardTest(t, "waffo-pancake-guard", 101, PaymentMethodStripe)
|
||||
insertTopUpForPaymentGuardTest(t, "waffo-pancake-guard", 101, PaymentProviderStripe)
|
||||
|
||||
err := RechargeWaffoPancake("waffo-pancake-guard")
|
||||
require.Error(t, err)
|
||||
@@ -100,27 +102,27 @@ func TestRechargeWaffoPancake_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
assert.Equal(t, 0, getUserQuotaForPaymentGuardTest(t, 101))
|
||||
}
|
||||
|
||||
func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentProvider(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
tradeNo string
|
||||
storedPaymentMethod string
|
||||
expectedPaymentMethod string
|
||||
targetStatus string
|
||||
name string
|
||||
tradeNo string
|
||||
storedPaymentProvider string
|
||||
expectedPaymentProvider string
|
||||
targetStatus string
|
||||
}{
|
||||
{
|
||||
name: "stripe expire",
|
||||
tradeNo: "stripe-expire-guard",
|
||||
storedPaymentMethod: PaymentMethodCreem,
|
||||
expectedPaymentMethod: PaymentMethodStripe,
|
||||
targetStatus: common.TopUpStatusExpired,
|
||||
name: "stripe expire",
|
||||
tradeNo: "stripe-expire-guard",
|
||||
storedPaymentProvider: PaymentProviderCreem,
|
||||
expectedPaymentProvider: PaymentProviderStripe,
|
||||
targetStatus: common.TopUpStatusExpired,
|
||||
},
|
||||
{
|
||||
name: "waffo failed",
|
||||
tradeNo: "waffo-failed-guard",
|
||||
storedPaymentMethod: PaymentMethodStripe,
|
||||
expectedPaymentMethod: PaymentMethodWaffo,
|
||||
targetStatus: common.TopUpStatusFailed,
|
||||
name: "waffo failed",
|
||||
tradeNo: "waffo-failed-guard",
|
||||
storedPaymentProvider: PaymentProviderStripe,
|
||||
expectedPaymentProvider: PaymentProviderWaffo,
|
||||
targetStatus: common.TopUpStatusFailed,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -128,23 +130,23 @@ func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
truncateTables(t)
|
||||
insertUserForPaymentGuardTest(t, 150, 0)
|
||||
insertTopUpForPaymentGuardTest(t, tc.tradeNo, 150, tc.storedPaymentMethod)
|
||||
insertTopUpForPaymentGuardTest(t, tc.tradeNo, 150, tc.storedPaymentProvider)
|
||||
|
||||
err := UpdatePendingTopUpStatus(tc.tradeNo, tc.expectedPaymentMethod, tc.targetStatus)
|
||||
err := UpdatePendingTopUpStatus(tc.tradeNo, tc.expectedPaymentProvider, tc.targetStatus)
|
||||
require.ErrorIs(t, err, ErrPaymentMethodMismatch)
|
||||
assert.Equal(t, common.TopUpStatusPending, getTopUpStatusForPaymentGuardTest(t, tc.tradeNo))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentProvider(t *testing.T) {
|
||||
truncateTables(t)
|
||||
|
||||
insertUserForPaymentGuardTest(t, 202, 0)
|
||||
plan := insertSubscriptionPlanForPaymentGuardTest(t, 301)
|
||||
insertSubscriptionOrderForPaymentGuardTest(t, "sub-guard-order", 202, plan.Id, PaymentMethodStripe)
|
||||
insertSubscriptionOrderForPaymentGuardTest(t, "sub-guard-order", 202, plan.Id, PaymentProviderStripe)
|
||||
|
||||
err := CompleteSubscriptionOrder("sub-guard-order", `{"provider":"epay"}`, "alipay")
|
||||
err := CompleteSubscriptionOrder("sub-guard-order", `{"provider":"epay"}`, PaymentProviderEpay, "alipay")
|
||||
require.ErrorIs(t, err, ErrPaymentMethodMismatch)
|
||||
|
||||
order := GetSubscriptionOrderByTradeNo("sub-guard-order")
|
||||
@@ -156,14 +158,14 @@ func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T)
|
||||
assert.Nil(t, topUp)
|
||||
}
|
||||
|
||||
func TestExpireSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
func TestExpireSubscriptionOrder_RejectsMismatchedPaymentProvider(t *testing.T) {
|
||||
truncateTables(t)
|
||||
|
||||
insertUserForPaymentGuardTest(t, 303, 0)
|
||||
plan := insertSubscriptionPlanForPaymentGuardTest(t, 401)
|
||||
insertSubscriptionOrderForPaymentGuardTest(t, "sub-expire-guard", 303, plan.Id, PaymentMethodStripe)
|
||||
insertSubscriptionOrderForPaymentGuardTest(t, "sub-expire-guard", 303, plan.Id, PaymentProviderStripe)
|
||||
|
||||
err := ExpireSubscriptionOrder("sub-expire-guard", PaymentMethodCreem)
|
||||
err := ExpireSubscriptionOrder("sub-expire-guard", PaymentProviderCreem)
|
||||
require.ErrorIs(t, err, ErrPaymentMethodMismatch)
|
||||
|
||||
order := GetSubscriptionOrderByTradeNo("sub-expire-guard")
|
||||
|
||||
+15
-9
@@ -198,11 +198,12 @@ type SubscriptionOrder struct {
|
||||
PlanId int `json:"plan_id" gorm:"index"`
|
||||
Money float64 `json:"money"`
|
||||
|
||||
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
|
||||
PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
|
||||
Status string `json:"status"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
|
||||
PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
|
||||
PaymentProvider string `json:"payment_provider" gorm:"type:varchar(50);default:''"`
|
||||
Status string `json:"status"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
|
||||
ProviderPayload string `json:"provider_payload" gorm:"type:text"`
|
||||
}
|
||||
@@ -505,7 +506,9 @@ func CreateUserSubscriptionFromPlanTx(tx *gorm.DB, userId int, plan *Subscriptio
|
||||
}
|
||||
|
||||
// Complete a subscription order (idempotent). Creates a UserSubscription snapshot from the plan.
|
||||
func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedPaymentMethod string) error {
|
||||
// expectedPaymentProvider guards against cross-gateway callback attacks (empty skips the check).
|
||||
// actualPaymentMethod updates the order's PaymentMethod to reflect the real payment type used (empty skips update).
|
||||
func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedPaymentProvider string, actualPaymentMethod string) error {
|
||||
if tradeNo == "" {
|
||||
return errors.New("tradeNo is empty")
|
||||
}
|
||||
@@ -523,7 +526,7 @@ func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedP
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil {
|
||||
return ErrSubscriptionOrderNotFound
|
||||
}
|
||||
if expectedPaymentMethod != "" && order.PaymentMethod != expectedPaymentMethod {
|
||||
if expectedPaymentProvider != "" && order.PaymentProvider != expectedPaymentProvider {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
if order.Status == common.TopUpStatusSuccess {
|
||||
@@ -552,6 +555,9 @@ func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedP
|
||||
if providerPayload != "" {
|
||||
order.ProviderPayload = providerPayload
|
||||
}
|
||||
if actualPaymentMethod != "" && order.PaymentMethod != actualPaymentMethod {
|
||||
order.PaymentMethod = actualPaymentMethod
|
||||
}
|
||||
if err := tx.Save(&order).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -610,7 +616,7 @@ func upsertSubscriptionTopUpTx(tx *gorm.DB, order *SubscriptionOrder) error {
|
||||
return tx.Save(&topup).Error
|
||||
}
|
||||
|
||||
func ExpireSubscriptionOrder(tradeNo string, expectedPaymentMethod string) error {
|
||||
func ExpireSubscriptionOrder(tradeNo string, expectedPaymentProvider string) error {
|
||||
if tradeNo == "" {
|
||||
return errors.New("tradeNo is empty")
|
||||
}
|
||||
@@ -623,7 +629,7 @@ func ExpireSubscriptionOrder(tradeNo string, expectedPaymentMethod string) error
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil {
|
||||
return ErrSubscriptionOrderNotFound
|
||||
}
|
||||
if expectedPaymentMethod != "" && order.PaymentMethod != expectedPaymentMethod {
|
||||
if expectedPaymentProvider != "" && order.PaymentProvider != expectedPaymentProvider {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
if order.Status != common.TopUpStatusPending {
|
||||
|
||||
+25
-16
@@ -12,15 +12,16 @@ import (
|
||||
)
|
||||
|
||||
type TopUp struct {
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Amount int64 `json:"amount"`
|
||||
Money float64 `json:"money"`
|
||||
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
|
||||
PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
Status string `json:"status"`
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Amount int64 `json:"amount"`
|
||||
Money float64 `json:"money"`
|
||||
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
|
||||
PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
|
||||
PaymentProvider string `json:"payment_provider" gorm:"type:varchar(50);default:''"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -30,6 +31,14 @@ const (
|
||||
PaymentMethodWaffoPancake = "waffo_pancake"
|
||||
)
|
||||
|
||||
const (
|
||||
PaymentProviderEpay = "epay"
|
||||
PaymentProviderStripe = "stripe"
|
||||
PaymentProviderCreem = "creem"
|
||||
PaymentProviderWaffo = "waffo"
|
||||
PaymentProviderWaffoPancake = "waffo_pancake"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPaymentMethodMismatch = errors.New("payment method mismatch")
|
||||
ErrTopUpNotFound = errors.New("topup not found")
|
||||
@@ -68,7 +77,7 @@ func GetTopUpByTradeNo(tradeNo string) *TopUp {
|
||||
return topUp
|
||||
}
|
||||
|
||||
func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentMethod string, targetStatus string) error {
|
||||
func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentProvider string, targetStatus string) error {
|
||||
if tradeNo == "" {
|
||||
return errors.New("未提供支付单号")
|
||||
}
|
||||
@@ -83,7 +92,7 @@ func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentMethod string, targ
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(topUp).Error; err != nil {
|
||||
return ErrTopUpNotFound
|
||||
}
|
||||
if expectedPaymentMethod != "" && topUp.PaymentMethod != expectedPaymentMethod {
|
||||
if expectedPaymentProvider != "" && topUp.PaymentProvider != expectedPaymentProvider {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
if topUp.Status != common.TopUpStatusPending {
|
||||
@@ -114,7 +123,7 @@ func Recharge(referenceId string, customerId string, callerIp string) (err error
|
||||
return errors.New("充值订单不存在")
|
||||
}
|
||||
|
||||
if topUp.PaymentMethod != PaymentMethodStripe {
|
||||
if topUp.PaymentProvider != PaymentProviderStripe {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
|
||||
@@ -340,7 +349,7 @@ func ManualCompleteTopUp(tradeNo string, callerIp string) error {
|
||||
// 计算应充值额度:
|
||||
// - Stripe 订单:Money 代表经分组倍率换算后的美元数量,直接 * QuotaPerUnit
|
||||
// - 其他订单(如易支付):Amount 为美元数量,* QuotaPerUnit
|
||||
if topUp.PaymentMethod == PaymentMethodStripe {
|
||||
if topUp.PaymentProvider == PaymentProviderStripe {
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
quotaToAdd = int(decimal.NewFromFloat(topUp.Money).Mul(dQuotaPerUnit).IntPart())
|
||||
} else {
|
||||
@@ -397,7 +406,7 @@ func RechargeCreem(referenceId string, customerEmail string, customerName string
|
||||
return errors.New("充值订单不存在")
|
||||
}
|
||||
|
||||
if topUp.PaymentMethod != PaymentMethodCreem {
|
||||
if topUp.PaymentProvider != PaymentProviderCreem {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
|
||||
@@ -472,7 +481,7 @@ func RechargeWaffo(tradeNo string, callerIp string) (err error) {
|
||||
return errors.New("充值订单不存在")
|
||||
}
|
||||
|
||||
if topUp.PaymentMethod != PaymentMethodWaffo {
|
||||
if topUp.PaymentProvider != PaymentProviderWaffo {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
|
||||
@@ -535,7 +544,7 @@ func RechargeWaffoPancake(tradeNo string) (err error) {
|
||||
return errors.New("充值订单不存在")
|
||||
}
|
||||
|
||||
if topUp.PaymentMethod != PaymentMethodWaffoPancake {
|
||||
if topUp.PaymentProvider != PaymentProviderWaffoPancake {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user