mirror of
https://github.com/gitpod-io/gitpod.git
synced 2025-12-08 17:36:30 +00:00
458 lines
14 KiB
Go
458 lines
14 KiB
Go
// Copyright (c) 2022 Gitpod GmbH. All rights reserved.
|
|
// Licensed under the GNU Affero General Public License (AGPL).
|
|
// See License.AGPL.txt in the project root for license information.
|
|
|
|
package stripe
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"os"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gitpod-io/gitpod/common-go/log"
|
|
db "github.com/gitpod-io/gitpod/components/gitpod-db/go"
|
|
"github.com/stripe/stripe-go/v72"
|
|
"github.com/stripe/stripe-go/v72/client"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
const (
|
|
AttributionIDMetadataKey = "attributionId"
|
|
PreferredCurrencyMetadataKey = "preferredCurrency"
|
|
)
|
|
|
|
type Client struct {
|
|
sc *client.API
|
|
}
|
|
|
|
type ClientConfig struct {
|
|
PublishableKey string `json:"publishableKey"`
|
|
SecretKey string `json:"secretKey"`
|
|
}
|
|
|
|
type PriceConfig struct {
|
|
EUR string `json:"eur"`
|
|
USD string `json:"usd"`
|
|
}
|
|
|
|
type StripePrices struct {
|
|
IndividualUsagePriceIDs PriceConfig `json:"individualUsagePriceIds"`
|
|
TeamUsagePriceIDs PriceConfig `json:"teamUsagePriceIds"`
|
|
}
|
|
|
|
func ReadConfigFromFile(path string) (ClientConfig, error) {
|
|
bytes, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return ClientConfig{}, fmt.Errorf("failed to read stripe client config: %w", err)
|
|
}
|
|
|
|
var config ClientConfig
|
|
err = json.Unmarshal(bytes, &config)
|
|
if err != nil {
|
|
return ClientConfig{}, fmt.Errorf("failed to unmarshal Stripe Client config: %w", err)
|
|
}
|
|
|
|
return config, nil
|
|
}
|
|
|
|
// New authenticates a Stripe client using the provided config
|
|
func New(config ClientConfig) (*Client, error) {
|
|
return NewWithHTTPClient(config, &http.Client{
|
|
Transport: http.DefaultTransport,
|
|
Timeout: 10 * time.Second,
|
|
})
|
|
}
|
|
|
|
func NewWithHTTPClient(config ClientConfig, c *http.Client) (*Client, error) {
|
|
sc := &client.API{}
|
|
|
|
sc.Init(config.SecretKey, stripe.NewBackends(&http.Client{
|
|
Transport: http.DefaultTransport,
|
|
Timeout: 10 * time.Second,
|
|
}))
|
|
|
|
return &Client{sc: sc}, nil
|
|
}
|
|
|
|
type UsageRecord struct {
|
|
SubscriptionItemID string
|
|
Quantity int64
|
|
}
|
|
|
|
type Invoice struct {
|
|
ID string
|
|
SubscriptionID string
|
|
Amount int64
|
|
Currency string
|
|
Credits int64
|
|
}
|
|
|
|
// UpdateUsage updates teams' Stripe subscriptions with usage data
|
|
// `usageForTeam` is a map from team name to total workspace seconds used within a billing period.
|
|
func (c *Client) UpdateUsage(ctx context.Context, creditsPerAttributionID map[db.AttributionID]int64) error {
|
|
queries := queriesForCustomersWithAttributionIDs(creditsPerAttributionID)
|
|
|
|
for _, query := range queries {
|
|
logger := log.WithField("stripe_query", query)
|
|
|
|
customers, err := c.findCustomers(ctx, query)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to find customers: %w", err)
|
|
}
|
|
|
|
for _, customer := range customers {
|
|
attributionIDRaw := customer.Metadata[AttributionIDMetadataKey]
|
|
logger.Infof("Found customer %q for attribution ID %q", customer.Name, attributionIDRaw)
|
|
|
|
attributionID, err := db.ParseAttributionID(attributionIDRaw)
|
|
if err != nil {
|
|
logger.WithError(err).Error("Failed to parse attribution ID from Stripe metadata.")
|
|
continue
|
|
}
|
|
|
|
lgr := logger.
|
|
WithField("customer_id", customer.ID).
|
|
WithField("customer_name", customer.Name).
|
|
WithField("subscriptions", customer.Subscriptions).
|
|
WithField("attribution_id", attributionID)
|
|
|
|
err = c.updateUsageForCustomer(ctx, customer, creditsPerAttributionID[attributionID])
|
|
if err != nil {
|
|
lgr.WithError(err).Errorf("Failed to update usage.")
|
|
}
|
|
reportStripeUsageUpdate(err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) findCustomers(ctx context.Context, query string) (customers []*stripe.Customer, err error) {
|
|
now := time.Now()
|
|
reportStripeRequestStarted("customers_search")
|
|
defer func() {
|
|
reportStripeRequestCompleted("customers_search", err, time.Since(now))
|
|
}()
|
|
|
|
params := &stripe.CustomerSearchParams{
|
|
SearchParams: stripe.SearchParams{
|
|
Query: query,
|
|
Expand: []*string{stripe.String("data.subscriptions")},
|
|
Context: ctx,
|
|
},
|
|
}
|
|
iter := c.sc.Customers.Search(params)
|
|
|
|
for iter.Next() {
|
|
customers = append(customers, iter.Customer())
|
|
}
|
|
if iter.Err() != nil {
|
|
return nil, fmt.Errorf("failed to search for customers: %w", iter.Err())
|
|
}
|
|
|
|
return customers, nil
|
|
}
|
|
|
|
func (c *Client) updateUsageForCustomer(ctx context.Context, customer *stripe.Customer, credits int64) (err error) {
|
|
logger := log.
|
|
WithField("customer_id", customer.ID).
|
|
WithField("customer_name", customer.Name).
|
|
WithField("credits", credits)
|
|
if credits < 0 {
|
|
logger.Infof("Received request to update customer %s usage to negative value, updating to 0 instead.", customer.ID)
|
|
|
|
// nullify any existing usage, but do not set it to negative value - negative invoice doesn't make sense...
|
|
credits = 0
|
|
}
|
|
|
|
subscriptions := customer.Subscriptions.Data
|
|
if len(subscriptions) == 0 {
|
|
logger.Info("Customer does not have a valid (one) subscription. This happens when a customer has cancelled their subscription but still exist as a Customer in stripe.")
|
|
return nil
|
|
}
|
|
if len(subscriptions) > 1 {
|
|
return fmt.Errorf("customer has more than 1 subscription (got: %d), this is a consistency error and requires a manual intervention", len(subscriptions))
|
|
}
|
|
|
|
subscription := customer.Subscriptions.Data[0]
|
|
|
|
log.Infof("Customer has subscription: %q", subscription.ID)
|
|
if len(subscription.Items.Data) != 1 {
|
|
return fmt.Errorf("subscription %s has an unexpected number of subscriptionItems (expected 1, got %d)", subscription.ID, len(subscription.Items.Data))
|
|
}
|
|
|
|
subscriptionItemId := subscription.Items.Data[0].ID
|
|
log.Infof("Registering usage against subscriptionItem %q", subscriptionItemId)
|
|
|
|
reportStripeRequestStarted("usage_record_update")
|
|
now := time.Now()
|
|
defer func() {
|
|
reportStripeRequestCompleted("usage_record_update", err, time.Since(now))
|
|
}()
|
|
_, err = c.sc.UsageRecords.New(&stripe.UsageRecordParams{
|
|
Params: stripe.Params{
|
|
Context: ctx,
|
|
},
|
|
SubscriptionItem: stripe.String(subscriptionItemId),
|
|
Quantity: stripe.Int64(credits),
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to register usage for customer %q on subscription item %s", customer.Name, subscriptionItemId)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) GetCustomerByAttributionID(ctx context.Context, attributionID string) (*stripe.Customer, error) {
|
|
customers, err := c.findCustomers(ctx, fmt.Sprintf("metadata['attributionId']:'%s'", attributionID))
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Internal, "failed to find customers: %v", err)
|
|
}
|
|
|
|
if len(customers) == 0 {
|
|
return nil, status.Errorf(codes.NotFound, "no customer found for attribution_id: %s", attributionID)
|
|
}
|
|
if len(customers) > 1 {
|
|
return nil, status.Errorf(codes.FailedPrecondition, "found multiple customers for attributiuon_id: %s", attributionID)
|
|
}
|
|
|
|
return customers[0], nil
|
|
}
|
|
|
|
func (c *Client) GetCustomer(ctx context.Context, customerID string) (customer *stripe.Customer, err error) {
|
|
now := time.Now()
|
|
reportStripeRequestStarted("customer_get")
|
|
defer func() {
|
|
reportStripeRequestCompleted("customer_get", err, time.Since(now))
|
|
}()
|
|
|
|
customer, err = c.sc.Customers.Get(customerID, &stripe.CustomerParams{
|
|
Params: stripe.Params{
|
|
Context: ctx,
|
|
Expand: []*string{stripe.String("tax"), stripe.String("subscriptions")},
|
|
},
|
|
})
|
|
if err != nil {
|
|
if stripeErr, ok := err.(*stripe.Error); ok {
|
|
switch stripeErr.Code {
|
|
case stripe.ErrorCodeMissing:
|
|
return nil, status.Errorf(codes.NotFound, "customer %s does not exist in stripe", customerID)
|
|
}
|
|
}
|
|
|
|
return nil, fmt.Errorf("failed to get customer by customer ID %s", customerID)
|
|
}
|
|
|
|
return customer, nil
|
|
}
|
|
|
|
func (c *Client) GetPriceInformation(ctx context.Context, priceID string) (price *stripe.Price, err error) {
|
|
now := time.Now()
|
|
reportStripeRequestStarted("prices_get")
|
|
defer func() {
|
|
reportStripeRequestCompleted("prices_get", err, time.Since(now))
|
|
}()
|
|
|
|
price, err = c.sc.Prices.Get(priceID, &stripe.PriceParams{
|
|
Params: stripe.Params{
|
|
Context: ctx,
|
|
},
|
|
})
|
|
if err != nil {
|
|
if stripeErr, ok := err.(*stripe.Error); ok {
|
|
switch stripeErr.Code {
|
|
case stripe.ErrorCodeMissing:
|
|
return nil, status.Errorf(codes.NotFound, "price %s does not exist in stripe", priceID)
|
|
}
|
|
}
|
|
|
|
return nil, fmt.Errorf("failed to get price by price ID %s", priceID)
|
|
}
|
|
|
|
return price, nil
|
|
}
|
|
|
|
type CreateCustomerParams struct {
|
|
AttributuonID string
|
|
Currency string
|
|
Email string
|
|
Name string
|
|
}
|
|
|
|
func (c *Client) CreateCustomer(ctx context.Context, params CreateCustomerParams) (customer *stripe.Customer, err error) {
|
|
now := time.Now()
|
|
reportStripeRequestStarted("customers_create")
|
|
defer func() {
|
|
reportStripeRequestCompleted("customers_create", err, time.Since(now))
|
|
}()
|
|
|
|
customer, err = c.sc.Customers.New(&stripe.CustomerParams{
|
|
Params: stripe.Params{
|
|
Context: ctx,
|
|
Metadata: map[string]string{
|
|
// We set the preferred currency on the metadata such that we can later retreive it when we're creating a Subscription
|
|
// This is also done to propagate the preference into the Customer such that we can inform them when their
|
|
// new subscription would use a different currency to the previous one
|
|
PreferredCurrencyMetadataKey: params.Currency,
|
|
AttributionIDMetadataKey: params.AttributuonID,
|
|
},
|
|
},
|
|
Email: stripe.String(params.Email),
|
|
Name: stripe.String(params.Name),
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create Stripe customer: %w", err)
|
|
}
|
|
return customer, nil
|
|
}
|
|
|
|
func (c *Client) GetInvoiceWithCustomer(ctx context.Context, invoiceID string) (invoice *stripe.Invoice, err error) {
|
|
if invoiceID == "" {
|
|
return nil, fmt.Errorf("no invoice ID specified")
|
|
}
|
|
|
|
now := time.Now()
|
|
reportStripeRequestStarted("invoice_get")
|
|
defer func() {
|
|
reportStripeRequestCompleted("invoice_get", err, time.Since(now))
|
|
}()
|
|
|
|
invoice, err = c.sc.Invoices.Get(invoiceID, &stripe.InvoiceParams{
|
|
Params: stripe.Params{
|
|
Context: ctx,
|
|
Expand: []*string{stripe.String("customer")},
|
|
},
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get invoice %s: %w", invoiceID, err)
|
|
}
|
|
return invoice, nil
|
|
}
|
|
|
|
func (c *Client) GetSubscriptionWithCustomer(ctx context.Context, subscriptionID string) (subscription *stripe.Subscription, err error) {
|
|
if subscriptionID == "" {
|
|
return nil, fmt.Errorf("no subscriptionID specified")
|
|
}
|
|
|
|
now := time.Now()
|
|
reportStripeRequestStarted("subscription_get")
|
|
defer func() {
|
|
reportStripeRequestCompleted("subscription_get", err, time.Since(now))
|
|
}()
|
|
|
|
subscription, err = c.sc.Subscriptions.Get(subscriptionID, &stripe.SubscriptionParams{
|
|
Params: stripe.Params{
|
|
Expand: []*string{stripe.String("customer")},
|
|
},
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get subscription %s: %w", subscriptionID, err)
|
|
}
|
|
return subscription, nil
|
|
}
|
|
|
|
func (c *Client) CreateSubscription(ctx context.Context, customerID string, priceID string, isAutomaticTaxSupported bool) (*stripe.Subscription, error) {
|
|
if customerID == "" {
|
|
return nil, fmt.Errorf("no customerID specified")
|
|
}
|
|
if priceID == "" {
|
|
return nil, fmt.Errorf("no priceID specified")
|
|
}
|
|
|
|
params := &stripe.SubscriptionParams{
|
|
Customer: stripe.String(customerID),
|
|
Items: []*stripe.SubscriptionItemsParams{
|
|
{
|
|
Price: stripe.String(priceID),
|
|
},
|
|
},
|
|
AutomaticTax: &stripe.SubscriptionAutomaticTaxParams{
|
|
Enabled: stripe.Bool(isAutomaticTaxSupported),
|
|
},
|
|
}
|
|
|
|
subscription, err := c.sc.Subscriptions.New(params)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get subscription with customer ID %s", customerID)
|
|
}
|
|
|
|
return subscription, err
|
|
}
|
|
|
|
func (c *Client) SetDefaultPaymentForCustomer(ctx context.Context, customerID string, setupIntentId string) (*stripe.Customer, error) {
|
|
if customerID == "" {
|
|
return nil, fmt.Errorf("no customerID specified")
|
|
}
|
|
|
|
if setupIntentId == "" {
|
|
return nil, fmt.Errorf("no setupIntentID specified")
|
|
}
|
|
|
|
setupIntent, err := c.sc.SetupIntents.Get(setupIntentId, &stripe.SetupIntentParams{
|
|
Params: stripe.Params{
|
|
Context: ctx,
|
|
},
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Failed to retrieve setup intent with id %s", setupIntentId)
|
|
}
|
|
|
|
paymentMethod, err := c.sc.PaymentMethods.Attach(setupIntent.PaymentMethod.ID, &stripe.PaymentMethodAttachParams{Customer: &customerID})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Failed to attach payment method to setup intent ID %s", setupIntentId)
|
|
}
|
|
|
|
customer, _ := c.sc.Customers.Update(customerID, &stripe.CustomerParams{
|
|
InvoiceSettings: &stripe.CustomerInvoiceSettingsParams{
|
|
DefaultPaymentMethod: stripe.String(paymentMethod.ID)},
|
|
Address: &stripe.AddressParams{
|
|
Line1: stripe.String(paymentMethod.BillingDetails.Address.Line1),
|
|
Country: stripe.String(paymentMethod.BillingDetails.Address.Country)}})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Failed to update customer with id %s", customerID)
|
|
}
|
|
|
|
return customer, nil
|
|
}
|
|
|
|
func GetAttributionID(ctx context.Context, customer *stripe.Customer) (db.AttributionID, error) {
|
|
if customer == nil {
|
|
log.Error("No customer information available for invoice.")
|
|
return "", status.Errorf(codes.Internal, "Failed to retrieve customer details from invoice.")
|
|
}
|
|
return db.ParseAttributionID(customer.Metadata[AttributionIDMetadataKey])
|
|
}
|
|
|
|
// queriesForCustomersWithAttributionIDs constructs Stripe query strings to find the Stripe Customer for each teamId
|
|
// It returns multiple queries, each being a big disjunction of subclauses so that we can process multiple teamIds in one query.
|
|
// `clausesPerQuery` is a limit enforced by the Stripe API.
|
|
func queriesForCustomersWithAttributionIDs(creditsByAttributionID map[db.AttributionID]int64) []string {
|
|
attributionIDs := make([]string, 0, len(creditsByAttributionID))
|
|
for k := range creditsByAttributionID {
|
|
attributionIDs = append(attributionIDs, string(k))
|
|
}
|
|
sort.Strings(attributionIDs)
|
|
|
|
const clausesPerQuery = 10
|
|
var queries []string
|
|
sb := strings.Builder{}
|
|
|
|
for i := 0; i < len(attributionIDs); i += clausesPerQuery {
|
|
sb.Reset()
|
|
for j := 0; j < clausesPerQuery && i+j < len(attributionIDs); j++ {
|
|
sb.WriteString(fmt.Sprintf("metadata['%s']:'%s'", AttributionIDMetadataKey, attributionIDs[i+j]))
|
|
if j < clausesPerQuery-1 && i+j < len(attributionIDs)-1 {
|
|
sb.WriteString(" OR ")
|
|
}
|
|
}
|
|
queries = append(queries, sb.String())
|
|
}
|
|
|
|
return queries
|
|
}
|