gitpod/dev/sweeper/main.go
2021-10-29 16:00:07 +02:00

231 lines
5.9 KiB
Go

// Copyright (c) 2020 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License-AGPL.txt in the project root for license information.
package main
import (
"context"
"database/sql"
"fmt"
"log"
"net/http"
"os"
"os/exec"
"os/signal"
"strings"
"syscall"
"time"
_ "github.com/go-sql-driver/mysql"
"github.com/google/go-github/v39/github"
flag "github.com/spf13/pflag"
"golang.org/x/oauth2"
)
// Activity check: MySQL connection settings and the inactivity window after
// which the configured command is executed.
var (
	dbUser, dbPass, dbHost, dbName string
	timeout                        time.Duration
)

// Branch check: the GitHub owner/repo/branch this environment is tied to, and
// the name of the environment variable holding the GitHub access token.
var (
	owner, repo, branch string
	tokenEnvVarName     string
)

// Settings shared by both checks.
var (
	command           string        // command executed once a check fires
	period            time.Duration // interval between two consecutive checks
	readyEndpointAddr string        // listen address of the readiness endpoint
)
// init registers every command-line flag on pflag's global CommandLine set.
// main parses them via flag.Parse.
func init() {
	fs := flag.CommandLine

	// Flags for the activity check.
	fs.StringVar(&dbHost, "db-host", "db:3306", "database hostname")
	fs.StringVar(&dbName, "db-name", "gitpod", "database name")
	fs.StringVar(&dbUser, "db-user", "root", "database username")
	fs.StringVar(&dbPass, "db-pass", "test", "database password")
	fs.DurationVar(&timeout, "timeout", 4*time.Hour, "time until the dev-staging installation is removed - must be a valid duration")

	// Flags for the branch-exists check.
	fs.StringVar(&owner, "owner", "", "the owner of the repo this preview env is associated with")
	fs.StringVar(&repo, "repo", "", "the repo this preview env is associated with")
	fs.StringVar(&branch, "branch", "", "the branch this preview env is associated with")
	fs.StringVar(&tokenEnvVarName, "tokenEnvVarName", "", "the name of the environment variable containing the GH token")

	// Flags shared by both checks.
	fs.DurationVar(&period, "period", 1*time.Minute, "time between checks - must be a valid duration")
	fs.StringVarP(&command, "command", "c", "echo time is up", "command to execute once we've timed out")
	fs.StringVar(&readyEndpointAddr, "ready-endpoint-addr", ":8080", "address where to serve the Kubernetes ready endpoint")
}
// main parses the flags and serves a readiness endpoint. It then starts the
// branch check (only when --branch is set) and the activity check as
// background goroutines, and blocks until SIGINT or SIGTERM is received.
func main() {
	flag.Parse()
	log.Printf("sweeper started")
	// time.NewTicker panics on a non-positive interval; fail with a clear message instead.
	if period <= 0 {
		log.Fatalf("period must be positive, got %s", period)
	}

	go func() {
		http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
			fmt.Fprintf(w, "ready")
		})
		log.Printf("serving ready endpoint on %s", readyEndpointAddr)
		log.Fatal(http.ListenAndServe(readyEndpointAddr, nil))
	}()

	// start checks
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	if branch != "" {
		log.Printf("starting branch check")
		// checkIfBranchExists only returns once ctx is cancelled. Calling it
		// inline would block the activity check and the signal handling below forever.
		go checkIfBranchExists(ctx)
	}
	log.Printf("starting activity check")
	go checkForRecentActivity(ctx)

	// shutdown
	termChan := make(chan os.Signal, 1)
	signal.Notify(termChan, syscall.SIGINT, syscall.SIGTERM)
	<-termChan
	log.Println("received signal, shutting down...")
	defer log.Println("shut down.")
}
// checkIfBranchExists polls GitHub every `period` for owner/repo/branch and
// executes `command` on every tick at which the branch is reported gone.
// It terminates the process via log.Fatal when owner/repo/branch or the
// token environment variable is not configured. It returns once ctx is cancelled.
func checkIfBranchExists(ctx context.Context) {
	if owner == "" || repo == "" || branch == "" {
		log.Fatalf("one of owner/repo/branch (%s/%s/%s) is not properly configured!", owner, repo, branch)
	}
	if tokenEnvVarName == "" {
		log.Fatal("tokenEnvVarName is not configured!")
	}
	token := os.Getenv(tokenEnvVarName)
	if token == "" {
		log.Fatalf("configured env var '%s' is empty!", tokenEnvVarName)
	}
	ts := oauth2.StaticTokenSource(
		&oauth2.Token{AccessToken: token},
	)
	tc := oauth2.NewClient(ctx, ts)
	client := github.NewClient(tc)

	tick := time.NewTicker(period)
	// Release the ticker when the loop exits on ctx cancellation.
	defer tick.Stop()
	for {
		gone, err := fetchBranch(ctx, client)
		switch {
		case err != nil:
			// The branch state is unknown. Don't claim it is present; just retry on the next tick.
			log.Printf("unable to fetch branch, skipping: %v", err)
		case gone:
			log.Printf("branch %s in repo %s/%s is gone, executing command", branch, owner, repo)
			// execute logs its own failure. The command is re-run on every
			// tick for as long as the branch stays gone.
			_ = execute(command)
		default:
			log.Printf("branch %s in repo %s/%s still present", branch, owner, repo)
		}
		select {
		case <-ctx.Done():
			return
		case <-tick.C:
		}
	}
}
// fetchBranch asks GitHub for the configured owner/repo/branch.
// A 404 response yields gone=true with a nil error. Any other failure is
// returned as err with gone=false. Success yields gone=false, err=nil.
func fetchBranch(ctx context.Context, client *github.Client) (gone bool, err error) {
	_, res, fetchErr := client.Repositories.GetBranch(ctx, owner, repo, branch, true)
	switch {
	case res != nil && res.StatusCode == http.StatusNotFound:
		return true, nil
	case fetchErr != nil:
		return false, fetchErr
	default:
		return false, nil
	}
}
// checkForRecentActivity opens the MySQL database configured by the db-* flags.
// Every `period` it looks up the most recent activity and executes `command`
// on each tick at which that activity is older than `timeout`.
// It terminates the process via log.Fatal if the DB handle cannot be created
// or no activity timestamp can be determined. It returns once ctx is cancelled.
func checkForRecentActivity(ctx context.Context) {
	db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s", dbUser, dbPass, dbHost, dbName))
	if err != nil {
		// log.Fatalf exits the process, so no explicit return is needed.
		log.Fatalf("cannot connect to DB: %+v", err)
	}
	defer db.Close()

	tick := time.NewTicker(period)
	// Release the ticker when the loop exits on ctx cancellation.
	defer tick.Stop()
	for {
		t0 := getLastActivity(db)
		if t0 == nil {
			log.Fatalf("cannot determine last activity")
		}
		dt := time.Since(*t0)
		log.Printf("last activity: %v (%s ago, %s until timeout)", t0.Format(time.RFC3339), dt.String(), (timeout - dt).String())
		if dt > timeout {
			log.Printf("timeout after %s, executing command: %s", dt.String(), command)
			// execute logs its own failure. The command is retried on the next tick.
			_ = execute(command)
		}
		select {
		case <-ctx.Done():
			return
		case <-tick.C:
		}
	}
}
// getLastActivity runs a fixed set of single-value queries against db and
// returns the newest timestamp among those it could both read and parse.
// It returns nil when no probe produced a usable value. A probe that fails
// to query or parse is logged and skipped.
func getLastActivity(db *sql.DB) (lastActivity *time.Time) {
	log.Printf("attempting to determine last time of activity")

	type probe struct {
		Name   string
		Query  string
		Format string // time.Parse layout for the column's string form
	}
	probes := [...]probe{
		{Name: "latest instance", Query: "SELECT creationTime FROM d_b_workspace_instance ORDER BY creationTime DESC LIMIT 1", Format: time.RFC3339},
		{Name: "latest user", Query: "SELECT creationDate FROM d_b_user ORDER BY creationDate DESC LIMIT 1", Format: time.RFC3339},
		{Name: "heartbeat", Query: "SELECT lastSeen FROM d_b_workspace_instance_user ORDER BY lastSeen DESC LIMIT 1", Format: "2006-01-02 15:04:05.999999"},
	}

	for i := range probes {
		p := &probes[i]

		var raw string
		if err := db.QueryRow(p.Query).Scan(&raw); err != nil {
			log.Printf("cannot query %s: %+v", p.Name, err)
			continue
		}

		ts, err := time.Parse(p.Format, raw)
		if err != nil {
			log.Printf("cannot parse %s: %+v", p.Name, err)
			continue
		}

		// Keep the current candidate unless this probe is strictly newer.
		if lastActivity != nil && !ts.After(*lastActivity) {
			continue
		}
		lastActivity = &ts
	}
	return lastActivity
}
// execute runs command as a child process, inheriting this process's
// environment, stdout and stderr, and waits for it to finish.
// The command line is split on runs of whitespace. No shell quoting is
// supported, so an argument cannot contain spaces.
// A failure (including an empty command) is logged and returned.
func execute(command string) error {
	// strings.Fields (unlike Split on " ") never yields empty arguments for
	// repeated or leading/trailing whitespace.
	segs := strings.Fields(command)
	if len(segs) == 0 {
		// Guard segs[0]: a blank command has nothing to run.
		err := fmt.Errorf("command is empty")
		log.Printf("cannot run command \"%s\": %+v", command, err)
		return err
	}
	cmd := exec.Command(segs[0], segs[1:]...)
	cmd.Env = os.Environ()
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	err := cmd.Run()
	if err != nil {
		log.Printf("cannot run command \"%s\": %+v", command, err)
	}
	return err
}