Sven Efftinge 2fd43d6307 [usage] update stripe usage on finalization
update stripe with current usage immediately, so that
invoices created between now and the next reconcile
are correct.
2023-02-10 14:33:15 +01:00

510 lines
18 KiB
Go

// Copyright (c) 2022 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.
package apiv1
import (
"context"
"errors"
"fmt"
"math"
"time"
"github.com/gitpod-io/gitpod/common-go/log"
db "github.com/gitpod-io/gitpod/components/gitpod-db/go"
v1 "github.com/gitpod-io/gitpod/usage-api/v1"
"github.com/gitpod-io/gitpod/usage/pkg/stripe"
"github.com/google/uuid"
stripe_api "github.com/stripe/stripe-go/v72"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"gorm.io/gorm"
)
// NewBillingService returns a BillingService that uses the given Stripe
// client, database connection, cost center manager and Stripe price
// configuration.
func NewBillingService(stripeClient *stripe.Client, conn *gorm.DB, ccManager *db.CostCenterManager, stripePrices stripe.StripePrices) *BillingService {
	svc := new(BillingService)
	svc.conn = conn
	svc.stripeClient = stripeClient
	svc.ccManager = ccManager
	svc.stripePrices = stripePrices
	return svc
}
// BillingService carries the dependencies shared by the billing RPC handlers
// defined in this file (Stripe customers, subscriptions, invoices and price
// information).
type BillingService struct {
// conn is the database handle handed to the db package helpers.
conn *gorm.DB
// stripeClient performs all calls to Stripe.
stripeClient *stripe.Client
// ccManager reads and updates cost centers.
ccManager *db.CostCenterManager
// stripePrices holds the configured Stripe price IDs; the handlers read
// IndividualUsagePriceIDs and TeamUsagePriceIDs (USD and EUR) from it.
stripePrices stripe.StripePrices
// NOTE(review): presumably supplies default implementations for RPCs this
// type does not define itself — generated code is not visible here.
v1.UnimplementedBillingServiceServer
}
// GetStripeCustomer resolves a Stripe customer either by attribution ID or by
// Stripe customer ID.
//
// The database is consulted first. When the record is missing there
// (db.ErrorNotFound) the customer is fetched from Stripe and stored in the
// database so that later lookups do not need to go to Stripe; a failure to
// store is only logged and the Stripe data is still returned.
//
// Errors:
//   - InvalidArgument: unparsable attribution ID, empty Stripe customer ID, or
//     an identifier of an unknown kind.
//   - NotFound: the database lookup failed for a reason other than a missing
//     record.
//   - Internal: the attribution ID could not be read from the Stripe customer.
//   - Errors from the Stripe client are returned unchanged.
func (s *BillingService) GetStripeCustomer(ctx context.Context, req *v1.GetStripeCustomerRequest) (*v1.GetStripeCustomerResponse, error) {
	// storeStripeCustomerAndRespond persists a customer fetched from Stripe and
	// builds the response for it.
	storeStripeCustomerAndRespond := func(ctx context.Context, cus *stripe_api.Customer, attributionID db.AttributionID) (*v1.GetStripeCustomerResponse, error) {
		logger := log.WithField("stripe_customer_id", cus.ID).WithField("attribution_id", attributionID)

		// Store it in the DB such that subsequent lookups don't need to go to Stripe.
		result, err := s.storeStripeCustomer(ctx, cus, attributionID)
		if err != nil {
			logger.WithError(err).Error("Failed to store stripe customer in the database.")
			// Storing failed, but we don't want to block the caller since we do
			// have the data; return it as a success.
			return &v1.GetStripeCustomerResponse{
				Customer: convertStripeCustomer(cus),
			}, nil
		}

		return &v1.GetStripeCustomerResponse{
			Customer: result,
		}, nil
	}

	switch identifier := req.GetIdentifier().(type) {
	case *v1.GetStripeCustomerRequest_AttributionId:
		attributionID, err := db.ParseAttributionID(identifier.AttributionId)
		if err != nil {
			// Report the caller's input, not the value returned by the failed parse.
			return nil, status.Errorf(codes.InvalidArgument, "Invalid attribution ID %s", identifier.AttributionId)
		}

		logger := log.WithField("attribution_id", attributionID)

		customer, err := db.GetStripeCustomerByAttributionID(ctx, s.conn, attributionID)
		if err != nil {
			// We don't yet have it in the DB.
			if errors.Is(err, db.ErrorNotFound) {
				stripeCustomer, err := s.stripeClient.GetCustomerByAttributionID(ctx, string(attributionID))
				if err != nil {
					return nil, err
				}

				return storeStripeCustomerAndRespond(ctx, stripeCustomer, attributionID)
			}

			logger.WithError(err).Error("Failed to lookup stripe customer from DB")

			// Do not fall through: without this return the caller would receive
			// a successful response built from a zero-value customer. Mirrors
			// the Stripe-customer-ID branch below.
			return nil, status.Errorf(codes.NotFound, "Failed to lookup stripe customer from DB: %s", err.Error())
		}

		return &v1.GetStripeCustomerResponse{
			Customer: convertDBStripeCustomerToResponse(customer),
		}, nil

	case *v1.GetStripeCustomerRequest_StripeCustomerId:
		if identifier.StripeCustomerId == "" {
			return nil, status.Errorf(codes.InvalidArgument, "empty stripe customer ID supplied")
		}

		logger := log.WithField("stripe_customer_id", identifier.StripeCustomerId)

		customer, err := db.GetStripeCustomer(ctx, s.conn, identifier.StripeCustomerId)
		if err != nil {
			// We don't yet have it in the DB.
			if errors.Is(err, db.ErrorNotFound) {
				stripeCustomer, err := s.stripeClient.GetCustomer(ctx, identifier.StripeCustomerId)
				if err != nil {
					return nil, err
				}

				attributionID, err := stripe.GetAttributionID(ctx, stripeCustomer)
				if err != nil {
					return nil, status.Errorf(codes.Internal, "Failed to parse attribution ID from Stripe customer %s", stripeCustomer.ID)
				}

				return storeStripeCustomerAndRespond(ctx, stripeCustomer, attributionID)
			}

			logger.WithError(err).Error("Failed to lookup stripe customer from DB")

			return nil, status.Errorf(codes.NotFound, "Failed to lookup stripe customer from DB: %s", err.Error())
		}

		return &v1.GetStripeCustomerResponse{
			Customer: convertDBStripeCustomerToResponse(customer),
		}, nil

	default:
		return nil, status.Errorf(codes.InvalidArgument, "Unknown identifier")
	}
}
// CreateStripeCustomer creates a customer in Stripe for the given attribution
// ID and records it in the database.
//
// The attribution ID must parse, and currency, email and name must be
// non-empty; otherwise InvalidArgument is returned. A failure to create the
// customer in Stripe is logged and reported as Internal. A failure to store the
// created customer in the database is only logged: the response still carries
// the customer created in Stripe.
func (s *BillingService) CreateStripeCustomer(ctx context.Context, req *v1.CreateStripeCustomerRequest) (*v1.CreateStripeCustomerResponse, error) {
	attributionID, err := db.ParseAttributionID(req.GetAttributionId())
	if err != nil {
		// Report the caller's input, not the value returned by the failed parse.
		return nil, status.Errorf(codes.InvalidArgument, "Invalid attribution ID %s", req.GetAttributionId())
	}
	if req.GetCurrency() == "" {
		return nil, status.Error(codes.InvalidArgument, "Invalid currency specified")
	}
	if req.GetEmail() == "" {
		return nil, status.Error(codes.InvalidArgument, "Invalid email specified")
	}
	if req.GetName() == "" {
		return nil, status.Error(codes.InvalidArgument, "Invalid name specified")
	}

	customer, err := s.stripeClient.CreateCustomer(ctx, stripe.CreateCustomerParams{
		AttributuonID: string(attributionID),
		Currency:      req.GetCurrency(),
		Email:         req.GetEmail(),
		Name:          req.GetName(),
	})
	if err != nil {
		log.WithError(err).Errorf("Failed to create stripe customer.")
		return nil, status.Errorf(codes.Internal, "Failed to create stripe customer")
	}

	err = db.CreateStripeCustomer(ctx, s.conn, db.StripeCustomer{
		StripeCustomerID: customer.ID,
		AttributionID:    attributionID,
		// Use the creation timestamp reported by Stripe rather than the local clock.
		CreationTime: db.NewVarCharTime(time.Unix(customer.Created, 0)),
		Currency:     req.GetCurrency(),
	})
	if err != nil {
		log.WithField("attribution_id", attributionID).WithField("stripe_customer_id", customer.ID).WithError(err).Error("Failed to store Stripe Customer in the database.")
		// We do not return an error to the caller here, as we did manage to
		// create the customer in Stripe and can proceed with other flows.
		// The StripeCustomer will be backfilled in the DB on the next
		// GetStripeCustomer call by doing a search.
	}

	return &v1.CreateStripeCustomerResponse{
		Customer: convertStripeCustomer(customer),
	}, nil
}
// CreateStripeSubscription creates a usage subscription in Stripe for the
// customer behind the given attribution ID.
//
// Steps: resolve the customer via GetStripeCustomer, set the default payment
// method from the request's setup intent, re-read the customer from Stripe,
// refuse if any subscription whose status is not "canceled" exists, pick the
// price via getPriceIdentifier and create the subscription. Automatic tax is
// requested only when Stripe reports it as "supported" for the customer.
//
// Errors: InvalidArgument for an unparsable attribution ID or when setting the
// default payment fails; AlreadyExists when a non-canceled subscription is
// present; Internal when Stripe fails to create the subscription. Errors from
// GetStripeCustomer, the Stripe customer lookup and getPriceIdentifier are
// returned unchanged.
func (s *BillingService) CreateStripeSubscription(ctx context.Context, req *v1.CreateStripeSubscriptionRequest) (*v1.CreateStripeSubscriptionResponse, error) {
	attributionID, err := db.ParseAttributionID(req.GetAttributionId())
	if err != nil {
		// Report the caller's input, not the value returned by the failed parse.
		return nil, status.Errorf(codes.InvalidArgument, "Invalid attribution ID %s", req.GetAttributionId())
	}

	customer, err := s.GetStripeCustomer(ctx, &v1.GetStripeCustomerRequest{
		Identifier: &v1.GetStripeCustomerRequest_AttributionId{
			AttributionId: string(attributionID),
		},
	})
	if err != nil {
		return nil, err
	}

	_, err = s.stripeClient.SetDefaultPaymentForCustomer(ctx, customer.Customer.Id, req.SetupIntentId)
	if err != nil {
		return nil, status.Errorf(codes.InvalidArgument, "Failed to set default payment for customer ID %s", customer.Customer.Id)
	}

	stripeCustomer, err := s.stripeClient.GetCustomer(ctx, customer.Customer.Id)
	if err != nil {
		return nil, err
	}

	// If the customer already has a subscription that is not canceled, refuse
	// to create another one.
	for _, subscription := range stripeCustomer.Subscriptions.Data {
		if subscription.Status != "canceled" {
			return nil, status.Errorf(codes.AlreadyExists, "Customer (%s) already has an active subscription (%s)", attributionID, subscription.ID)
		}
	}

	priceID, err := getPriceIdentifier(attributionID, stripeCustomer, s)
	if err != nil {
		return nil, err
	}

	// Tax information may be absent on the customer; treat that as unsupported.
	var isAutomaticTaxSupported bool
	if stripeCustomer.Tax != nil {
		isAutomaticTaxSupported = stripeCustomer.Tax.AutomaticTax == "supported"
	}
	if !isAutomaticTaxSupported {
		log.Warnf("Automatic Stripe tax is not supported for customer %s", stripeCustomer.ID)
	}

	subscription, err := s.stripeClient.CreateSubscription(ctx, stripeCustomer.ID, priceID, isAutomaticTaxSupported)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "Failed to create subscription with customer ID %s", customer.Customer.Id)
	}

	return &v1.CreateStripeSubscriptionResponse{
		Subscription: &v1.StripeSubscription{
			Id: subscription.ID,
		},
	}, nil
}
// getPriceIdentifier selects the configured Stripe price ID for a customer.
//
// The price list is chosen by the entity of attributionID (user -> individual
// prices, team -> team prices). Within that list the customer's
// "preferredCurrency" metadata selects EUR; every other value, including a
// missing one, selects USD. A missing preferred currency is logged as a
// warning.
//
// An attribution entity that is neither user nor team yields an
// InvalidArgument error.
func getPriceIdentifier(attributionID db.AttributionID, stripeCustomer *stripe_api.Customer, s *BillingService) (string, error) {
	preferredCurrency := stripeCustomer.Metadata["preferredCurrency"]
	if preferredCurrency == "" {
		log.
			WithField("stripe_customer_id", stripeCustomer.ID).
			Warn("No preferred currency set. Defaulting to USD")
	}

	entity, _ := attributionID.Values()

	switch entity {
	case db.AttributionEntity_User:
		switch preferredCurrency {
		case "EUR":
			return s.stripePrices.IndividualUsagePriceIDs.EUR, nil
		default:
			return s.stripePrices.IndividualUsagePriceIDs.USD, nil
		}

	case db.AttributionEntity_Team:
		switch preferredCurrency {
		case "EUR":
			return s.stripePrices.TeamUsagePriceIDs.EUR, nil
		default:
			return s.stripePrices.TeamUsagePriceIDs.USD, nil
		}

	default:
		// This branch is reached because of the entity, not the currency, so
		// the error names the entity.
		return "", status.Errorf(codes.InvalidArgument, "Unsupported attribution entity %v for customer ID %s", entity, stripeCustomer.ID)
	}
}
// ReconcileInvoices pushes the current balances to Stripe.
//
// It lists all balances from the database, keeps those whose cost center is
// billed through Stripe, converts each balance to whole credits (rounded up)
// and reports them to Stripe in a single UpdateUsage call. Every failure is
// reported to the caller as an Internal error.
func (s *BillingService) ReconcileInvoices(ctx context.Context, in *v1.ReconcileInvoicesRequest) (*v1.ReconcileInvoicesResponse, error) {
	allBalances, err := db.ListBalance(ctx, s.conn)
	if err != nil {
		log.WithError(err).Errorf("Failed to reconcile invoices.")
		return nil, status.Errorf(codes.Internal, "Failed to reconcile invoices.")
	}

	billable, err := balancesForStripeCostCenters(ctx, s.ccManager, allBalances)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "Failed to identify stripe balances.")
	}

	// Values are whole credits, rounded up from the stored balance.
	usageByAttribution := make(map[db.AttributionID]int64, len(billable))
	for _, b := range billable {
		credits := b.CreditCents.ToCredits()
		usageByAttribution[b.AttributionID] = int64(math.Ceil(credits))
	}

	if err := s.stripeClient.UpdateUsage(ctx, usageByAttribution); err != nil {
		log.WithError(err).Errorf("Failed to udpate usage in stripe.")
		return nil, status.Errorf(codes.Internal, "Failed to update usage in stripe")
	}

	return &v1.ReconcileInvoicesResponse{}, nil
}
// FinalizeInvoice records a finalized Stripe invoice as a usage entry and then
// refreshes Stripe with the resulting balance.
//
// Steps: fetch the invoice (with its customer) from Stripe, derive the usage
// record via InternalComputeInvoiceUsage, insert it into the database,
// increment the billing cycle of the attribution ID and finally push the new
// balance (whole credits, rounded up) to Stripe.
//
// Errors: InvalidArgument when the invoice ID is empty, NotFound when the
// invoice cannot be fetched, Internal when the usage record cannot be
// inserted; an error from InternalComputeInvoiceUsage is returned unchanged.
// Everything after the insert is best effort: failures are logged and the call
// still succeeds.
func (s *BillingService) FinalizeInvoice(ctx context.Context, in *v1.FinalizeInvoiceRequest) (*v1.FinalizeInvoiceResponse, error) {
	logger := log.WithField("invoice_id", in.GetInvoiceId())

	if in.GetInvoiceId() == "" {
		return nil, status.Errorf(codes.InvalidArgument, "Missing InvoiceID")
	}

	invoice, err := s.stripeClient.GetInvoiceWithCustomer(ctx, in.GetInvoiceId())
	if err != nil {
		logger.WithError(err).Error("Failed to retrieve invoice from Stripe.")
		return nil, status.Errorf(codes.NotFound, "Failed to get invoice with ID %s: %s", in.GetInvoiceId(), err.Error())
	}

	usage, err := InternalComputeInvoiceUsage(ctx, invoice)
	if err != nil {
		return nil, err
	}

	err = db.InsertUsage(ctx, s.conn, usage)
	if err != nil {
		logger.WithError(err).Errorf("Failed to insert Invoice usage record into the db.")
		return nil, status.Errorf(codes.Internal, "Failed to insert Invoice into usage records.")
	}

	logger.WithField("usage_id", usage.ID).Infof("Inserted usage record into database for %f credits against %s attribution", usage.CreditCents.ToCredits(), usage.AttributionID)

	_, err = s.ccManager.IncrementBillingCycle(ctx, usage.AttributionID)
	if err != nil {
		// Only log: the usage has already been recorded, and we do not want to
		// see this event again.
		logger.WithError(err).Errorf("Failed to increment billing cycle.")
	}

	// Update Stripe with the current usage immediately, so that invoices
	// created between now and the next reconcile are correct.
	newBalance, err := db.GetBalance(ctx, s.conn, usage.AttributionID)
	if err != nil {
		// Only log: the usage has already been recorded, and we do not want to
		// see this event again.
		logger.WithError(err).Errorf("Failed to compute new balance.")
		return &v1.FinalizeInvoiceResponse{}, nil
	}

	err = s.stripeClient.UpdateUsage(ctx, map[db.AttributionID]int64{
		usage.AttributionID: int64(math.Ceil(newBalance.ToCredits())),
	})
	if err != nil {
		// Only log: the usage has already been recorded, and we do not want to
		// see this event again. Uses the request-scoped logger so the entry
		// carries the invoice ID like every other log line in this handler.
		logger.WithError(err).Errorf("Failed to udpate usage in stripe after receiving invoive.finalized.")
	}

	return &v1.FinalizeInvoiceResponse{}, nil
}
// InternalComputeInvoiceUsage builds the usage record that offsets a finalized
// Stripe invoice. It does not write to the database.
//
// The attribution ID is read from the invoice's customer. The quantities of
// all invoice lines are summed and stored as a negative credit amount, so that
// the record reduces the accrued usage. The record's effective time is the
// invoice's finalization timestamp.
//
// An error from stripe.GetAttributionID is returned unchanged. An invoice
// without any lines yields an Internal error.
func InternalComputeInvoiceUsage(ctx context.Context, invoice *stripe_api.Invoice) (db.Usage, error) {
	logger := log.WithField("invoice_id", invoice.ID)

	attributionID, err := stripe.GetAttributionID(ctx, invoice.Customer)
	if err != nil {
		return db.Usage{}, err
	}

	finalizedAt := time.Unix(invoice.StatusTransitions.FinalizedAt, 0)

	// The attribution ID is logged once, under the "attribution_id" key used
	// throughout this file.
	logger = logger.
		WithField("attribution_id", attributionID).
		WithField("invoice_finalized_at", finalizedAt)

	if invoice.Lines == nil || len(invoice.Lines.Data) == 0 {
		logger.Errorf("Invoice %s did not contain any lines so we cannot extract quantity to reflect it in usage.", invoice.ID)
		return db.Usage{}, status.Errorf(codes.Internal, "Invoice did not contain any lines.")
	}

	var creditsOnInvoice int64
	for _, line := range invoice.Lines.Data {
		creditsOnInvoice += line.Quantity
	}

	return db.Usage{
		ID:            uuid.New(),
		AttributionID: attributionID,
		Description:   fmt.Sprintf("Invoice %s finalized in Stripe", invoice.ID),
		// Apply negative value of credits to reduce accrued credit usage.
		CreditCents:   db.NewCreditCents(float64(-creditsOnInvoice)),
		EffectiveTime: db.NewVarCharTime(finalizedAt),
		Kind:          db.InvoiceUsageKind,
		Draft:         false,
		Metadata:      nil,
	}, nil
}
// CancelSubscription handles an ended Stripe subscription by switching the
// cost center of the subscription's customer to db.CostCenter_Other.
//
// An empty subscription ID yields InvalidArgument. Errors from the Stripe
// lookup, the attribution ID extraction and the cost center manager are
// returned unchanged.
func (s *BillingService) CancelSubscription(ctx context.Context, in *v1.CancelSubscriptionRequest) (*v1.CancelSubscriptionResponse, error) {
	subscriptionID := in.GetSubscriptionId()

	// Logged before validation, so the entry is written even for an empty ID.
	log.WithField("subscription_id", subscriptionID).Infof("Subscription ended. Setting cost center back to free.")

	if subscriptionID == "" {
		return nil, status.Errorf(codes.InvalidArgument, "subscriptionId is required")
	}

	sub, err := s.stripeClient.GetSubscriptionWithCustomer(ctx, subscriptionID)
	if err != nil {
		return nil, err
	}

	owner, err := stripe.GetAttributionID(ctx, sub.Customer)
	if err != nil {
		return nil, err
	}

	cc, err := s.ccManager.GetOrCreateCostCenter(ctx, owner)
	if err != nil {
		return nil, err
	}

	cc.BillingStrategy = db.CostCenter_Other
	if _, err := s.ccManager.UpdateCostCenter(ctx, cc); err != nil {
		return nil, err
	}

	return &v1.CancelSubscriptionResponse{}, nil
}
// getPriceId returns the Stripe price ID that applies to the given attribution
// ID. It never fails: whenever a step goes wrong it falls back to the team USD
// usage price.
//
// If the customer has a subscription whose status is not "canceled", the plan
// ID of the first such subscription is returned. Otherwise the price is chosen
// by getPriceIdentifier. Failures are logged, except a NotFound from
// GetStripeCustomer, which silently yields the fallback.
func (s *BillingService) getPriceId(ctx context.Context, attributionId string) string {
	fallback := s.stripePrices.TeamUsagePriceIDs.USD

	parsed, parseErr := db.ParseAttributionID(attributionId)
	if parseErr != nil {
		log.Errorf("Failed to parse attribution ID %s: %s", attributionId, parseErr.Error())
		return fallback
	}

	lookup := &v1.GetStripeCustomerRequest{
		Identifier: &v1.GetStripeCustomerRequest_AttributionId{
			AttributionId: string(parsed),
		},
	}
	known, lookupErr := s.GetStripeCustomer(ctx, lookup)
	if lookupErr != nil {
		// A missing customer is expected and not worth an error log.
		if status.Code(lookupErr) != codes.NotFound {
			log.Errorf("Failed to get stripe customer for attribution ID %s: %s", attributionId, lookupErr.Error())
		}
		return fallback
	}

	remote, fetchErr := s.stripeClient.GetCustomer(ctx, known.Customer.Id)
	if fetchErr != nil {
		log.Errorf("Failed to get customer infromation from stripe for customer ID %s: %s", known.Customer.Id, fetchErr.Error())
		return fallback
	}

	// A subscription that is not canceled decides the price.
	for _, sub := range remote.Subscriptions.Data {
		if sub.Status == "canceled" {
			continue
		}
		return sub.Plan.ID
	}

	chosen, chooseErr := getPriceIdentifier(parsed, remote, s)
	if chooseErr != nil {
		log.Errorf("Failed to get price identifier for attribution ID %s: %s", attributionId, chooseErr.Error())
		return fallback
	}

	return chosen
}
// GetPriceInformation returns the human readable description of the Stripe
// price that applies to the given attribution ID.
//
// The description is read from the price's "human_readable_description"
// metadata; when it is empty, "No information available" is returned instead.
// An unparsable attribution ID yields InvalidArgument; an error from the
// Stripe price lookup is returned unchanged.
func (s *BillingService) GetPriceInformation(ctx context.Context, req *v1.GetPriceInformationRequest) (*v1.GetPriceInformationResponse, error) {
	if _, err := db.ParseAttributionID(req.GetAttributionId()); err != nil {
		return nil, status.Errorf(codes.InvalidArgument, "Invalid attribution ID %s", req.GetAttributionId())
	}

	applicablePriceID := s.getPriceId(ctx, req.GetAttributionId())

	price, err := s.stripeClient.GetPriceInformation(ctx, applicablePriceID)
	if err != nil {
		return nil, err
	}

	description := "No information available"
	if fromMetadata := price.Metadata["human_readable_description"]; fromMetadata != "" {
		description = fromMetadata
	}

	resp := &v1.GetPriceInformationResponse{
		HumanReadableDescription: description,
	}
	return resp, nil
}
// storeStripeCustomer writes a customer obtained from Stripe to the database
// and returns its API representation. The error from the database insert is
// returned unchanged.
func (s *BillingService) storeStripeCustomer(ctx context.Context, cus *stripe_api.Customer, attributionID db.AttributionID) (*v1.StripeCustomer, error) {
	currency := cus.Metadata[stripe.PreferredCurrencyMetadataKey]

	record := db.StripeCustomer{
		StripeCustomerID: cus.ID,
		AttributionID:    attributionID,
		Currency:         currency,
		// We use the original Stripe supplied creation timestamp, this ensures
		// that we stay true to our ordering of customer creation records.
		CreationTime: db.NewVarCharTime(time.Unix(cus.Created, 0)),
	}
	if err := db.CreateStripeCustomer(ctx, s.conn, record); err != nil {
		return nil, err
	}

	stored := &v1.StripeCustomer{
		Id:       cus.ID,
		Currency: currency,
	}
	return stored, nil
}
// balancesForStripeCostCenters returns, in input order, the balances whose
// cost center uses the Stripe billing strategy. The cost center of every
// balance is obtained through cm.GetOrCreateCostCenter; the first error from
// that call aborts the filtering and is returned unchanged.
func balancesForStripeCostCenters(ctx context.Context, cm *db.CostCenterManager, balances []db.Balance) ([]db.Balance, error) {
	var stripeBilled []db.Balance

	for _, candidate := range balances {
		cc, err := cm.GetOrCreateCostCenter(ctx, candidate.AttributionID)
		if err != nil {
			return nil, err
		}

		// We only update Stripe usage when the AttributionID is billed against
		// Stripe (determined through the cost center).
		if cc.BillingStrategy == db.CostCenter_Stripe {
			stripeBilled = append(stripeBilled, candidate)
		}
	}

	return stripeBilled, nil
}
// convertStripeCustomer maps a Stripe API customer to its API representation.
// The currency is taken from the customer's preferred-currency metadata entry.
func convertStripeCustomer(customer *stripe_api.Customer) *v1.StripeCustomer {
	converted := new(v1.StripeCustomer)
	converted.Id = customer.ID
	converted.Currency = customer.Metadata[stripe.PreferredCurrencyMetadataKey]
	return converted
}
// convertDBStripeCustomerToResponse maps a stored Stripe customer record to
// its API representation, copying the Stripe customer ID and the currency.
func convertDBStripeCustomerToResponse(cus db.StripeCustomer) *v1.StripeCustomer {
	converted := new(v1.StripeCustomer)
	converted.Id = cus.StripeCustomerID
	converted.Currency = cus.Currency
	return converted
}