feat: add first page with auth and containers list and agents

This commit is contained in:
2026-05-18 08:24:02 +02:00
parent 446087ae01
commit 3b4a841bf5
56 changed files with 16267 additions and 0 deletions

20
server/Dockerfile Normal file
View File

@ -0,0 +1,20 @@
FROM golang:1.23-alpine AS builder
RUN apk add --no-cache gcc musl-dev
WORKDIR /src
COPY go.mod go.sum ./
COPY . .
RUN go mod tidy && CGO_ENABLED=1 GOOS=linux go build -ldflags="-s -w" -o /bin/containarr-server ./cmd/server
# ── Runtime ───────────────────────────────────────────────────────────────────
FROM alpine:3.20
RUN apk add --no-cache ca-certificates tzdata
COPY --from=builder /bin/containarr-server /usr/local/bin/containarr-server
VOLUME ["/data"]
EXPOSE 8080 9090
ENTRYPOINT ["containarr-server"]

139
server/cmd/server/main.go Normal file
View File

@ -0,0 +1,139 @@
package main
import (
"context"
"log/slog"
"net"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"github.com/containarr/server/internal/api"
"github.com/containarr/server/internal/broker"
grpcgateway "github.com/containarr/server/internal/grpc"
agentv1 "github.com/containarr/server/internal/proto/agentv1"
"github.com/containarr/server/internal/store"
"google.golang.org/grpc"
)
func main() {
slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, nil)))
dbPath := getenv("DB_PATH", "/data/containarr.db")
httpAddr := getenv("HTTP_ADDR", ":8080")
grpcAddr := getenv("GRPC_ADDR", ":9090")
db, err := store.New(dbPath)
must(err, "open store")
defer db.Close()
bootstrapAdmin(db)
bootstrapTokens(db)
reg := grpcgateway.NewRegistry()
brk := broker.New()
// gRPC server.
gw := grpcgateway.NewGateway(db, reg, brk)
grpcServer := grpc.NewServer()
agentv1.RegisterAgentGatewayServer(grpcServer, gw)
lis, err := net.Listen("tcp", grpcAddr)
must(err, "listen grpc")
go func() {
slog.Info("gRPC listening", "addr", grpcAddr)
if err := grpcServer.Serve(lis); err != nil {
slog.Error("gRPC serve", "err", err)
}
}()
// HTTP server.
h := api.NewHandler(db, reg, brk)
httpServer := &http.Server{
Addr: httpAddr,
Handler: api.NewRouter(h),
ReadTimeout: 10 * time.Second,
WriteTimeout: 0, // disabled for WebSocket handlers
}
go func() {
slog.Info("HTTP listening", "addr", httpAddr)
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
slog.Error("HTTP serve", "err", err)
}
}()
// Graceful shutdown.
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
slog.Info("shutting down")
grpcServer.GracefulStop()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
_ = httpServer.Shutdown(ctx)
}
func getenv(key, fallback string) string {
if v := os.Getenv(key); v != "" {
return v
}
return fallback
}
// bootstrapAdmin creates the admin user from env vars if it doesn't exist yet.
func bootstrapAdmin(db *store.Store) {
username := getenv("ADMIN_USER", "admin")
password := getenv("ADMIN_PASSWORD", "admin")
exists, err := db.UserExists(username)
if err != nil || exists {
return
}
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
slog.Error("bcrypt admin", "err", err)
return
}
if err := db.UpsertUser(username, string(hash)); err != nil {
slog.Error("seed admin user", "err", err)
return
}
slog.Info("admin user created", "username", username)
}
// bootstrapTokens seeds agent tokens from BOOTSTRAP_TOKENS env var.
// Format: "hostname:token,hostname2:token2"
func bootstrapTokens(db *store.Store) {
raw := os.Getenv("BOOTSTRAP_TOKENS")
if raw == "" {
return
}
for _, pair := range strings.Split(raw, ",") {
parts := strings.SplitN(strings.TrimSpace(pair), ":", 2)
if len(parts) != 2 {
continue
}
hostname, token := parts[0], parts[1]
if err := db.CreateAgentToken(uuid.NewString(), token, hostname); err != nil {
slog.Warn("bootstrap token already exists", "hostname", hostname)
} else {
slog.Info("bootstrapped agent token", "hostname", hostname)
}
}
}
func must(err error, msg string) {
if err != nil {
slog.Error(msg, "err", err)
os.Exit(1)
}
}

21
server/go.mod Normal file
View File

@ -0,0 +1,21 @@
module github.com/containarr/server
go 1.23
require (
github.com/go-chi/chi/v5 v5.1.0
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/mattn/go-sqlite3 v1.14.22
golang.org/x/crypto v0.21.0
google.golang.org/grpc v1.64.0
google.golang.org/protobuf v1.34.2
)
require (
golang.org/x/net v0.22.0 // indirect
golang.org/x/sys v0.18.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 // indirect
)

26
server/go.sum Normal file
View File

@ -0,0 +1,26 @@
github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw=
github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 h1:NnYq6UN9ReLM9/Y01KWNOWyI5xQ9kbIms5GGJVwS/Yc=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY=
google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY=
google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=

View File

@ -0,0 +1,550 @@
package api
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
"github.com/containarr/server/internal/broker"
grpcgateway "github.com/containarr/server/internal/grpc"
agentv1 "github.com/containarr/server/internal/proto/agentv1"
"github.com/containarr/server/internal/store"
"github.com/go-chi/chi/v5"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"
)
// ── helpers ───────────────────────────────────────────────────────────────────
func newTestHandler(t *testing.T) (*Handler, *store.Store, *grpcgateway.Registry, *broker.Broker) {
t.Helper()
s, err := store.New(":memory:")
if err != nil {
t.Fatalf("store.New: %v", err)
}
t.Cleanup(func() { s.Close() })
reg := grpcgateway.NewRegistry()
b := broker.New()
h := NewHandler(s, reg, b)
return h, s, reg, b
}
func makeJWT(t *testing.T, subject string) string {
t.Helper()
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &jwtClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: subject,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
},
})
signed, err := token.SignedString(jwtSecret())
if err != nil {
t.Fatalf("makeJWT: %v", err)
}
return signed
}
func bearerHeader(token string) string {
return "Bearer " + token
}
func postJSON(t *testing.T, handler http.HandlerFunc, path string, body any) *httptest.ResponseRecorder {
t.Helper()
b, _ := json.Marshal(body)
req := httptest.NewRequest(http.MethodPost, path, bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
handler(w, req)
return w
}
func postJSONAuth(t *testing.T, handler http.HandlerFunc, path string, body any, token string) *httptest.ResponseRecorder {
t.Helper()
b, _ := json.Marshal(body)
req := httptest.NewRequest(http.MethodPost, path, bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerHeader(token))
w := httptest.NewRecorder()
// Wrap handler with requireJWT so claims land in context.
requireJWT(handler).ServeHTTP(w, req)
return w
}
// ── extractToken ──────────────────────────────────────────────────────────────
func TestExtractToken_BearerHeader(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Authorization", "Bearer mytoken")
if got := extractToken(req); got != "mytoken" {
t.Errorf("expected 'mytoken', got %q", got)
}
}
func TestExtractToken_QueryParam(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/?token=querytoken", nil)
if got := extractToken(req); got != "querytoken" {
t.Errorf("expected 'querytoken', got %q", got)
}
}
func TestExtractToken_Empty(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
if got := extractToken(req); got != "" {
t.Errorf("expected empty, got %q", got)
}
}
func TestExtractToken_ShortAuthHeader(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Authorization", "Bear") // len < 7
if got := extractToken(req); got != "" {
t.Errorf("expected empty for short header, got %q", got)
}
}
// ── requireJWT middleware ─────────────────────────────────────────────────────
func TestRequireJWT_MissingToken(t *testing.T) {
called := false
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { called = true })
req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
requireJWT(next).ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("expected 401, got %d", w.Code)
}
if called {
t.Error("handler should not be called without token")
}
}
func TestRequireJWT_InvalidToken(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Authorization", "Bearer not.a.real.token")
w := httptest.NewRecorder()
requireJWT(next).ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("expected 401, got %d", w.Code)
}
}
func TestRequireJWT_ValidToken(t *testing.T) {
token := makeJWT(t, "alice")
called := false
var gotSubject string
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
c, ok := claimsFromContext(r)
if !ok {
t.Error("claims not in context")
return
}
gotSubject = c.Subject
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Authorization", bearerHeader(token))
w := httptest.NewRecorder()
requireJWT(next).ServeHTTP(w, req)
if !called {
t.Error("handler was not called")
}
if gotSubject != "alice" {
t.Errorf("expected subject 'alice', got %q", gotSubject)
}
}
func TestRequireJWT_WrongSecret(t *testing.T) {
// Sign with a different secret
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &jwtClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: "hacker",
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
},
})
signed, _ := token.SignedString([]byte("wrong-secret"))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Authorization", bearerHeader(signed))
w := httptest.NewRecorder()
requireJWT(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})).ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("expected 401, got %d", w.Code)
}
}
// ── Login ─────────────────────────────────────────────────────────────────────
func TestLogin_Success(t *testing.T) {
h, s, _, _ := newTestHandler(t)
hash, _ := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.MinCost)
_ = s.UpsertUser("alice", string(hash))
w := postJSON(t, h.Login, "/api/v1/auth/login", map[string]string{
"username": "alice",
"password": "password",
})
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d — body: %s", w.Code, w.Body.String())
}
var resp map[string]string
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("decode response: %v", err)
}
if resp["token"] == "" {
t.Error("expected non-empty token in response")
}
}
func TestLogin_WrongPassword(t *testing.T) {
h, s, _, _ := newTestHandler(t)
hash, _ := bcrypt.GenerateFromPassword([]byte("correct"), bcrypt.MinCost)
_ = s.UpsertUser("alice", string(hash))
w := postJSON(t, h.Login, "/api/v1/auth/login", map[string]string{
"username": "alice",
"password": "wrong",
})
if w.Code != http.StatusUnauthorized {
t.Errorf("expected 401, got %d", w.Code)
}
}
func TestLogin_UnknownUser(t *testing.T) {
h, _, _, _ := newTestHandler(t)
w := postJSON(t, h.Login, "/api/v1/auth/login", map[string]string{
"username": "nobody",
"password": "pass",
})
if w.Code != http.StatusUnauthorized {
t.Errorf("expected 401, got %d", w.Code)
}
}
func TestLogin_BadBody(t *testing.T) {
h, _, _, _ := newTestHandler(t)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", strings.NewReader("not-json"))
w := httptest.NewRecorder()
h.Login(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d", w.Code)
}
}
func TestLogin_EmptyFields(t *testing.T) {
h, _, _, _ := newTestHandler(t)
w := postJSON(t, h.Login, "/api/v1/auth/login", map[string]string{
"username": "",
"password": "",
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d", w.Code)
}
}
// ── ChangePassword ─────────────────────────────────────────────────────────────
func TestChangePassword_Success(t *testing.T) {
h, s, _, _ := newTestHandler(t)
os.Setenv("JWT_SECRET", "test-secret-change-pw")
defer os.Unsetenv("JWT_SECRET")
hash, _ := bcrypt.GenerateFromPassword([]byte("oldpass"), bcrypt.MinCost)
_ = s.UpsertUser("alice", string(hash))
token := makeJWT(t, "alice")
w := postJSONAuth(t, h.ChangePassword, "/api/v1/auth/change-password", map[string]string{
"current_password": "oldpass",
"new_password": "newpass",
}, token)
if w.Code != http.StatusNoContent {
t.Fatalf("expected 204, got %d — body: %s", w.Code, w.Body.String())
}
// Verify new hash is stored
newHash, _ := s.GetUserHash("alice")
if bcrypt.CompareHashAndPassword([]byte(newHash), []byte("newpass")) != nil {
t.Error("new password hash does not match")
}
}
func TestChangePassword_WrongCurrentPassword(t *testing.T) {
h, s, _, _ := newTestHandler(t)
hash, _ := bcrypt.GenerateFromPassword([]byte("correct"), bcrypt.MinCost)
_ = s.UpsertUser("alice", string(hash))
token := makeJWT(t, "alice")
w := postJSONAuth(t, h.ChangePassword, "/api/v1/auth/change-password", map[string]string{
"current_password": "wrong",
"new_password": "newpass",
}, token)
if w.Code != http.StatusUnauthorized {
t.Errorf("expected 401, got %d", w.Code)
}
}
// ── ListAgents ────────────────────────────────────────────────────────────────
func TestListAgents_Empty(t *testing.T) {
h, _, _, _ := newTestHandler(t)
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents", nil)
w := httptest.NewRecorder()
h.ListAgents(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
var agents []agentDTO
if err := json.NewDecoder(w.Body).Decode(&agents); err != nil {
t.Fatalf("decode: %v", err)
}
if len(agents) != 0 {
t.Errorf("expected empty list, got %d", len(agents))
}
}
func TestListAgents_PersistenceAndLive(t *testing.T) {
h, s, reg, _ := newTestHandler(t)
_ = s.CreateAgentToken("a1", "t1", "host1")
// Register a2 in the registry (simulating live agent)
reg.Register("a2", "host2", "alias2", "192.168.1.1", "arm64", "linux")
_ = s.CreateAgentToken("a2", "t2", "host2")
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents", nil)
w := httptest.NewRecorder()
h.ListAgents(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
var agents []agentDTO
json.NewDecoder(w.Body).Decode(&agents)
if len(agents) != 2 {
t.Fatalf("expected 2 agents, got %d", len(agents))
}
// Find a2 — it should be online
var a2 *agentDTO
for i := range agents {
if agents[i].ID == "a2" {
a2 = &agents[i]
}
}
if a2 == nil {
t.Fatal("a2 not found in list")
}
if !a2.Online {
t.Error("a2 should be online (registered in registry)")
}
}
// ── CreateAgentToken ──────────────────────────────────────────────────────────
func TestCreateAgentToken_Success(t *testing.T) {
h, _, _, _ := newTestHandler(t)
b, _ := json.Marshal(map[string]string{"hostname": "new-agent"})
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/token", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
h.CreateAgentToken(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d — body: %s", w.Code, w.Body.String())
}
var resp map[string]string
json.NewDecoder(w.Body).Decode(&resp)
if resp["agent_id"] == "" || resp["token"] == "" {
t.Errorf("missing agent_id or token in response: %v", resp)
}
}
func TestCreateAgentToken_MissingHostname(t *testing.T) {
h, _, _, _ := newTestHandler(t)
b, _ := json.Marshal(map[string]string{"hostname": ""})
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/token", bytes.NewReader(b))
w := httptest.NewRecorder()
h.CreateAgentToken(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d", w.Code)
}
}
// ── UpdateAgent ───────────────────────────────────────────────────────────────
func TestUpdateAgent_Success(t *testing.T) {
h, s, reg, _ := newTestHandler(t)
_ = s.CreateAgentToken("a1", "t1", "host1")
reg.Register("a1", "host1", "old", "ip", "arch", "os")
body, _ := json.Marshal(map[string]string{"alias": "new-alias"})
router := chi.NewRouter()
router.Patch("/api/v1/agents/{agentID}", h.UpdateAgent)
req, _ := http.NewRequest(http.MethodPatch, "/api/v1/agents/a1", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d — body: %s", w.Code, w.Body.String())
}
var resp agentDTO
json.NewDecoder(w.Body).Decode(&resp)
if resp.Alias != "new-alias" {
t.Errorf("expected alias 'new-alias', got %q", resp.Alias)
}
// Confirm registry also updated
state, ok := reg.Get("a1")
if !ok {
t.Fatal("agent not in registry")
}
if state.Alias != "new-alias" {
t.Errorf("registry alias not updated, got %q", state.Alias)
}
}
// ── ListContainers ─────────────────────────────────────────────────────────────
func TestListContainers_Empty(t *testing.T) {
h, _, _, _ := newTestHandler(t)
req := httptest.NewRequest(http.MethodGet, "/api/v1/containers", nil)
w := httptest.NewRecorder()
h.ListContainers(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
}
func TestListContainers_WithData(t *testing.T) {
h, _, reg, _ := newTestHandler(t)
reg.Register("a1", "host1", "alias1", "10.0.0.1", "amd64", "linux")
reg.UpdateContainers("a1", []*agentv1.ContainerInfo{
{Id: "c1", Name: "web"},
{Id: "c2", Name: "db"},
})
req := httptest.NewRequest(http.MethodGet, "/api/v1/containers", nil)
w := httptest.NewRecorder()
h.ListContainers(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
var out []struct {
AgentID string `json:"agent_id"`
Container *agentv1.ContainerInfo `json:"container"`
}
json.NewDecoder(w.Body).Decode(&out)
if len(out) != 2 {
t.Errorf("expected 2 containers, got %d", len(out))
}
}
// ── ContainerAction ───────────────────────────────────────────────────────────
func TestContainerAction_AgentNotConnected(t *testing.T) {
h, _, _, _ := newTestHandler(t)
body, _ := json.Marshal(map[string]string{"action": "start"})
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/ghost/containers/c1/action", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router := chi.NewRouter()
router.Post("/api/v1/agents/{agentID}/containers/{containerID}/action", h.ContainerAction)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Errorf("expected 503, got %d", w.Code)
}
}
func TestContainerAction_InvalidAction(t *testing.T) {
h, _, reg, _ := newTestHandler(t)
reg.Register("a1", "h", "a", "ip", "arch", "os")
body, _ := json.Marshal(map[string]string{"action": "explode"})
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a1/containers/c1/action", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router := chi.NewRouter()
router.Post("/api/v1/agents/{agentID}/containers/{containerID}/action", h.ContainerAction)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d", w.Code)
}
}
func TestContainerAction_Success(t *testing.T) {
h, _, reg, _ := newTestHandler(t)
reg.Register("a1", "h", "a", "ip", "arch", "os")
body, _ := json.Marshal(map[string]string{"action": "stop"})
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a1/containers/c1/action", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router := chi.NewRouter()
router.Post("/api/v1/agents/{agentID}/containers/{containerID}/action", h.ContainerAction)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d — body: %s", w.Code, w.Body.String())
}
var resp map[string]string
json.NewDecoder(w.Body).Decode(&resp)
if resp["command_id"] == "" {
t.Error("expected command_id in response")
}
}

125
server/internal/api/auth.go Normal file
View File

@ -0,0 +1,125 @@
package api
import (
"encoding/json"
"net/http"
"os"
"time"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"
)
func jwtSecret() []byte {
if s := os.Getenv("JWT_SECRET"); s != "" {
return []byte(s)
}
return []byte("dev-secret-change-me")
}
type jwtClaims struct {
jwt.RegisteredClaims
}
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
var body struct {
Username string `json:"username"`
Password string `json:"password"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil || body.Username == "" || body.Password == "" {
http.Error(w, "invalid body", http.StatusBadRequest)
return
}
hash, err := h.store.GetUserHash(body.Username)
if err != nil {
http.Error(w, "invalid credentials", http.StatusUnauthorized)
return
}
if bcrypt.CompareHashAndPassword([]byte(hash), []byte(body.Password)) != nil {
http.Error(w, "invalid credentials", http.StatusUnauthorized)
return
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &jwtClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: body.Username,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
},
})
signed, err := token.SignedString(jwtSecret())
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
jsonOK(w, map[string]string{"token": signed})
}
func (h *Handler) ChangePassword(w http.ResponseWriter, r *http.Request) {
claims, ok := claimsFromContext(r)
if !ok {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
var body struct {
CurrentPassword string `json:"current_password"`
NewPassword string `json:"new_password"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil || body.NewPassword == "" {
http.Error(w, "invalid body", http.StatusBadRequest)
return
}
hash, err := h.store.GetUserHash(claims.Subject)
if err != nil {
http.Error(w, "user not found", http.StatusNotFound)
return
}
if bcrypt.CompareHashAndPassword([]byte(hash), []byte(body.CurrentPassword)) != nil {
http.Error(w, "mot de passe actuel incorrect", http.StatusUnauthorized)
return
}
newHash, err := bcrypt.GenerateFromPassword([]byte(body.NewPassword), bcrypt.DefaultCost)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if err := h.store.UpsertUser(claims.Subject, string(newHash)); err != nil {
http.Error(w, "store error", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusNoContent)
}
func requireJWT(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
raw := extractToken(r)
if raw == "" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
t, err := jwt.ParseWithClaims(raw, &jwtClaims{}, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, jwt.ErrSignatureInvalid
}
return jwtSecret(), nil
})
if err != nil || !t.Valid {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r.WithContext(
contextWithClaims(r.Context(), t.Claims.(*jwtClaims)),
))
})
}
func extractToken(r *http.Request) string {
if auth := r.Header.Get("Authorization"); len(auth) > 7 && auth[:7] == "Bearer " {
return auth[7:]
}
return r.URL.Query().Get("token")
}

View File

@ -0,0 +1,16 @@
package api
import "context"
type contextKey int
const claimsKey contextKey = iota
func contextWithClaims(ctx context.Context, c *jwtClaims) context.Context {
return context.WithValue(ctx, claimsKey, c)
}
func claimsFromContext(r interface{ Context() context.Context }) (*jwtClaims, bool) {
c, ok := r.Context().Value(claimsKey).(*jwtClaims)
return c, ok && c != nil
}

View File

@ -0,0 +1,301 @@
package api
import (
"encoding/json"
"net/http"
"strconv"
"time"
"github.com/containarr/server/internal/broker"
grpcgateway "github.com/containarr/server/internal/grpc"
agentv1 "github.com/containarr/server/internal/proto/agentv1"
"github.com/containarr/server/internal/store"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/gorilla/websocket"
)
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
}
type Handler struct {
store *store.Store
registry *grpcgateway.Registry
broker *broker.Broker
}
func NewHandler(s *store.Store, r *grpcgateway.Registry, b *broker.Broker) *Handler {
return &Handler{store: s, registry: r, broker: b}
}
type agentDTO struct {
ID string `json:"id"`
Hostname string `json:"hostname"`
Alias string `json:"alias"`
IPAddress string `json:"ip_address"`
Arch string `json:"arch"`
OS string `json:"os"`
Online bool `json:"online"`
LastSeenAt time.Time `json:"last_seen_at"`
}
// ── Agents ────────────────────────────────────────────────────────────────────
func (h *Handler) ListAgents(w http.ResponseWriter, r *http.Request) {
persisted, err := h.store.ListAgents()
if err != nil {
http.Error(w, "store error", http.StatusInternalServerError)
return
}
liveByID := map[string]*grpcgateway.AgentState{}
for _, s := range h.registry.List() {
liveByID[s.ID] = s
}
out := make([]agentDTO, 0, len(persisted))
for _, a := range persisted {
dto := agentDTO{
ID: a.ID,
Hostname: a.Hostname,
Alias: a.Alias,
IPAddress: a.IPAddress,
Arch: a.Arch,
OS: a.OS,
}
if live, ok := liveByID[a.ID]; ok {
dto.Online = true
dto.IPAddress = live.IPAddress
dto.LastSeenAt = live.LastSeenAt
}
out = append(out, dto)
}
jsonOK(w, out)
}
func (h *Handler) UpdateAgent(w http.ResponseWriter, r *http.Request) {
agentID := chi.URLParam(r, "agentID")
var body struct {
Alias string `json:"alias"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
http.Error(w, "invalid body", http.StatusBadRequest)
return
}
if err := h.store.UpdateAgentAlias(agentID, body.Alias); err != nil {
http.Error(w, "store error", http.StatusInternalServerError)
return
}
h.registry.UpdateAlias(agentID, body.Alias)
a, err := h.store.GetAgent(agentID)
if err != nil {
http.Error(w, "not found", http.StatusNotFound)
return
}
jsonOK(w, agentDTO{
ID: a.ID,
Hostname: a.Hostname,
Alias: a.Alias,
IPAddress: a.IPAddress,
Arch: a.Arch,
OS: a.OS,
Online: a.Online,
})
}
// ── Containers ────────────────────────────────────────────────────────────────
func (h *Handler) ListContainers(w http.ResponseWriter, r *http.Request) {
type containerDTO struct {
AgentID string `json:"agent_id"`
Hostname string `json:"hostname"`
Alias string `json:"alias"`
IPAddress string `json:"ip_address"`
Container *agentv1.ContainerInfo `json:"container"`
}
var out []containerDTO
for _, agent := range h.registry.List() {
for _, c := range agent.Containers {
out = append(out, containerDTO{
AgentID: agent.ID,
Hostname: agent.Hostname,
Alias: agent.Alias,
IPAddress: agent.IPAddress,
Container: c,
})
}
}
jsonOK(w, out)
}
func (h *Handler) ContainerAction(w http.ResponseWriter, r *http.Request) {
agentID := chi.URLParam(r, "agentID")
containerID := chi.URLParam(r, "containerID")
var body struct {
Action string `json:"action"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
http.Error(w, "invalid body", http.StatusBadRequest)
return
}
action, ok := map[string]agentv1.ContainerAction{
"start": agentv1.ContainerAction_CONTAINER_ACTION_START,
"stop": agentv1.ContainerAction_CONTAINER_ACTION_STOP,
"restart": agentv1.ContainerAction_CONTAINER_ACTION_RESTART,
"remove": agentv1.ContainerAction_CONTAINER_ACTION_REMOVE,
}[body.Action]
if !ok {
http.Error(w, "unknown action", http.StatusBadRequest)
return
}
cmdID := uuid.NewString()
sent := h.registry.Send(agentID, &agentv1.ServerMessage{
Payload: &agentv1.ServerMessage_ContainerCmd{
ContainerCmd: &agentv1.ContainerCommand{
CommandId: cmdID,
ContainerId: containerID,
Action: action,
},
},
})
if !sent {
http.Error(w, "agent not connected", http.StatusServiceUnavailable)
return
}
jsonOK(w, map[string]string{"command_id": cmdID})
}
// ── Agent token provisioning ──────────────────────────────────────────────────
func (h *Handler) CreateAgentToken(w http.ResponseWriter, r *http.Request) {
var body struct {
Hostname string `json:"hostname"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil || body.Hostname == "" {
http.Error(w, "hostname required", http.StatusBadRequest)
return
}
id := uuid.NewString()
token := uuid.NewString()
if err := h.store.CreateAgentToken(id, token, body.Hostname); err != nil {
http.Error(w, "store error", http.StatusInternalServerError)
return
}
jsonOK(w, map[string]string{"agent_id": id, "token": token})
}
// ── Container log stream ──────────────────────────────────────────────────────
func (h *Handler) LogsWS(w http.ResponseWriter, r *http.Request) {
agentID := chi.URLParam(r, "agentID")
containerID := chi.URLParam(r, "containerID")
follow := r.URL.Query().Get("follow") != "false"
tail := int32(100)
if t := r.URL.Query().Get("tail"); t != "" {
if n, err := strconv.Atoi(t); err == nil && n > 0 {
tail = int32(n)
}
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
sent := h.registry.Send(agentID, &agentv1.ServerMessage{
Payload: &agentv1.ServerMessage_StreamLogs{
StreamLogs: &agentv1.StreamLogsCommand{
CommandId: uuid.NewString(),
ContainerId: containerID,
Follow: follow,
Tail: tail,
},
},
})
if !sent {
conn.WriteMessage(websocket.TextMessage, []byte(`{"error":"agent not connected"}`))
return
}
sub := h.broker.Subscribe()
defer h.broker.Unsubscribe(sub)
done := make(chan struct{})
go func() {
defer close(done)
for {
if _, _, err := conn.ReadMessage(); err != nil {
return
}
}
}()
for {
select {
case <-done:
return
case raw, ok := <-sub:
if !ok {
return
}
var envelope struct {
Type string `json:"type"`
AgentID string `json:"agent_id"`
Payload json.RawMessage `json:"payload"`
}
if json.Unmarshal(raw, &envelope) != nil {
continue
}
if envelope.Type != "log.chunk" || envelope.AgentID != agentID {
continue
}
var chunk struct {
ContainerID string `json:"container_id"`
Stream string `json:"stream"`
Data []byte `json:"data"`
}
if json.Unmarshal(envelope.Payload, &chunk) != nil {
continue
}
if chunk.ContainerID != containerID {
continue
}
msg, _ := json.Marshal(map[string]string{
"stream": chunk.Stream,
"line": string(chunk.Data),
})
if err := conn.WriteMessage(websocket.TextMessage, msg); err != nil {
return
}
}
}
}
// ── WebSocket event stream ────────────────────────────────────────────────────
func (h *Handler) EventsWS(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
sub := h.broker.Subscribe()
defer h.broker.Unsubscribe(sub)
for data := range sub {
if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
return
}
}
}
func jsonOK(w http.ResponseWriter, v any) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(v)
}

View File

@ -0,0 +1,35 @@
package api
import (
"net/http"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
)
func NewRouter(h *Handler) http.Handler {
r := chi.NewRouter()
r.Use(middleware.Logger)
r.Use(middleware.Recoverer)
r.Use(middleware.RealIP)
r.Route("/api/v1", func(r chi.Router) {
r.Post("/auth/login", h.Login)
r.Group(func(r chi.Router) {
r.Use(requireJWT)
r.Post("/auth/change-password", h.ChangePassword)
r.Get("/agents", h.ListAgents)
r.Post("/agents/token", h.CreateAgentToken)
r.Patch("/agents/{agentID}", h.UpdateAgent)
r.Get("/containers", h.ListContainers)
r.Post("/agents/{agentID}/containers/{containerID}/action", h.ContainerAction)
r.Get("/agents/{agentID}/containers/{containerID}/logs", h.LogsWS)
r.Get("/events", h.EventsWS)
})
})
r.Handle("/*", http.FileServer(http.Dir("./web/dist")))
return r
}

View File

@ -0,0 +1,49 @@
package auth
import (
"errors"
"time"
"github.com/golang-jwt/jwt/v5"
)
type Claims struct {
UserID string `json:"uid"`
jwt.RegisteredClaims
}
type Service struct {
secret []byte
}
func New(secret string) *Service {
return &Service{secret: []byte(secret)}
}
func (s *Service) Sign(userID string) (string, error) {
claims := Claims{
UserID: userID,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
},
}
return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(s.secret)
}
func (s *Service) Verify(tokenStr string) (*Claims, error) {
token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, errors.New("unexpected signing method")
}
return s.secret, nil
})
if err != nil {
return nil, err
}
claims, ok := token.Claims.(*Claims)
if !ok || !token.Valid {
return nil, errors.New("invalid token")
}
return claims, nil
}

View File

@ -0,0 +1,64 @@
package auth
import (
"testing"
"time"
)
func TestSignAndVerify(t *testing.T) {
svc := New("test-secret")
token, err := svc.Sign("user42")
if err != nil {
t.Fatalf("Sign: %v", err)
}
if token == "" {
t.Fatal("expected non-empty token")
}
claims, err := svc.Verify(token)
if err != nil {
t.Fatalf("Verify: %v", err)
}
if claims.UserID != "user42" {
t.Errorf("expected UserID 'user42', got %q", claims.UserID)
}
if claims.ExpiresAt == nil || claims.ExpiresAt.Before(time.Now()) {
t.Error("token should not be expired")
}
}
func TestVerify_InvalidToken(t *testing.T) {
svc := New("test-secret")
_, err := svc.Verify("not.a.valid.token")
if err == nil {
t.Fatal("expected error for invalid token")
}
}
func TestVerify_WrongSecret(t *testing.T) {
svc1 := New("secret-a")
svc2 := New("secret-b")
token, err := svc1.Sign("user1")
if err != nil {
t.Fatalf("Sign: %v", err)
}
_, err = svc2.Verify(token)
if err == nil {
t.Fatal("expected error when verifying with different secret")
}
}
func TestVerify_TamperedToken(t *testing.T) {
svc := New("test-secret")
token, _ := svc.Sign("admin")
// Append garbage to corrupt the signature.
tampered := token + "x"
_, err := svc.Verify(tampered)
if err == nil {
t.Fatal("expected error for tampered token")
}
}

View File

@ -0,0 +1,55 @@
package broker
import (
"encoding/json"
"sync"
)
// Event is a JSON-serialisable message pushed to WebSocket clients.
type Event struct {
Type string `json:"type"`
AgentID string `json:"agent_id,omitempty"`
Payload any `json:"payload"`
}
type subscriber chan []byte
// Broker fan-outs events to all registered WebSocket subscribers.
type Broker struct {
mu sync.RWMutex
subs map[subscriber]struct{}
}
func New() *Broker {
return &Broker{subs: make(map[subscriber]struct{})}
}
func (b *Broker) Subscribe() subscriber {
ch := make(subscriber, 32)
b.mu.Lock()
b.subs[ch] = struct{}{}
b.mu.Unlock()
return ch
}
func (b *Broker) Unsubscribe(ch subscriber) {
b.mu.Lock()
delete(b.subs, ch)
b.mu.Unlock()
close(ch)
}
func (b *Broker) Publish(evt Event) {
data, err := json.Marshal(evt)
if err != nil {
return
}
b.mu.RLock()
defer b.mu.RUnlock()
for ch := range b.subs {
select {
case ch <- data:
default: // drop if subscriber is slow
}
}
}

View File

@ -0,0 +1,123 @@
package broker
import (
"encoding/json"
"testing"
"time"
)
func TestSubscribePublishUnsubscribe(t *testing.T) {
b := New()
sub := b.Subscribe()
evt := Event{Type: "test.event", AgentID: "agent1", Payload: map[string]string{"k": "v"}}
b.Publish(evt)
select {
case raw := <-sub:
var got Event
if err := json.Unmarshal(raw, &got); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if got.Type != "test.event" || got.AgentID != "agent1" {
t.Errorf("unexpected event: %+v", got)
}
case <-time.After(time.Second):
t.Fatal("timed out waiting for event")
}
b.Unsubscribe(sub)
// channel must be closed after unsubscribe
select {
case _, ok := <-sub:
if ok {
t.Error("expected channel to be closed")
}
case <-time.After(time.Second):
t.Fatal("timed out waiting for channel close")
}
}
func TestMultipleSubscribers(t *testing.T) {
b := New()
sub1 := b.Subscribe()
sub2 := b.Subscribe()
defer b.Unsubscribe(sub1)
defer b.Unsubscribe(sub2)
b.Publish(Event{Type: "ping", Payload: nil})
for i, sub := range []subscriber{sub1, sub2} {
select {
case <-sub:
case <-time.After(time.Second):
t.Fatalf("subscriber %d did not receive event", i)
}
}
}
func TestPublishDropsWhenSubscriberSlow(t *testing.T) {
b := New()
// Channel size is 32; fill it up and then publish one more — it must not block.
sub := b.Subscribe()
defer b.Unsubscribe(sub)
// Fill the buffer
for i := 0; i < 32; i++ {
b.Publish(Event{Type: "flood", Payload: i})
}
// This extra publish must return immediately (dropped, not block).
done := make(chan struct{})
go func() {
b.Publish(Event{Type: "dropped", Payload: nil})
close(done)
}()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("Publish blocked on slow subscriber")
}
}
func TestPublishNoSubscribers(t *testing.T) {
b := New()
// Should not panic or block
b.Publish(Event{Type: "nobody", Payload: nil})
}
func TestPublishInvalidPayload(t *testing.T) {
b := New()
sub := b.Subscribe()
defer b.Unsubscribe(sub)
// json.Marshal of a channel fails — Publish must not send anything.
b.Publish(Event{Type: "bad", Payload: make(chan int)})
select {
case <-sub:
t.Error("should not have received a message for an unmarshalable event")
default:
// correct: nothing sent
}
}
func TestUnsubscribeRemovesFromBroker(t *testing.T) {
b := New()
sub := b.Subscribe()
b.Unsubscribe(sub)
// After unsubscribe the broker's map should be empty.
b.mu.RLock()
n := len(b.subs)
b.mu.RUnlock()
if n != 0 {
t.Errorf("expected 0 subscribers after unsubscribe, got %d", n)
}
}

View File

@ -0,0 +1,137 @@
package grpc
import (
"io"
"log/slog"
"net"
"github.com/containarr/server/internal/broker"
agentv1 "github.com/containarr/server/internal/proto/agentv1"
"github.com/containarr/server/internal/store"
"github.com/google/uuid"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
)
type Gateway struct {
agentv1.UnimplementedAgentGatewayServer
store *store.Store
registry *Registry
broker *broker.Broker
}
func NewGateway(s *store.Store, r *Registry, b *broker.Broker) *Gateway {
return &Gateway{store: s, registry: r, broker: b}
}
func (g *Gateway) Tunnel(stream agentv1.AgentGateway_TunnelServer) error {
if err := stream.SendHeader(metadata.MD{}); err != nil {
return status.Errorf(codes.Internal, "send header: %v", err)
}
first, err := stream.Recv()
if err != nil {
return err
}
hs := first.GetHandshake()
if hs == nil {
return status.Error(codes.InvalidArgument, "first message must be AgentHandshake")
}
existing, err := g.store.AgentByToken(hs.Token)
if err != nil {
return status.Error(codes.Unauthenticated, "unknown agent token")
}
// Extract peer IP from the gRPC connection.
ipAddress := ""
if p, ok := peer.FromContext(stream.Context()); ok {
if host, _, err := net.SplitHostPort(p.Addr.String()); err == nil {
ipAddress = host
}
}
agentID := existing.ID
slog.Info("agent connected", "id", agentID, "hostname", hs.Hostname, "ip", ipAddress)
state := g.registry.Register(agentID, hs.Hostname, existing.Alias, ipAddress, hs.Arch, hs.Os)
_ = g.store.UpsertAgent(&store.Agent{
ID: agentID,
Token: hs.Token,
Hostname: hs.Hostname,
Alias: existing.Alias,
IPAddress: ipAddress,
Arch: hs.Arch,
OS: hs.Os,
Online: true,
})
g.broker.Publish(broker.Event{
Type: "agent.connected",
AgentID: agentID,
Payload: map[string]string{"hostname": hs.Hostname},
})
defer func() {
g.registry.Deregister(agentID)
_ = g.store.SetAgentOffline(agentID)
g.broker.Publish(broker.Event{Type: "agent.disconnected", AgentID: agentID, Payload: nil})
slog.Info("agent disconnected", "id", agentID)
}()
errCh := make(chan error, 1)
go func() {
for msg := range state.cmdCh {
if err := stream.Send(msg); err != nil {
errCh <- err
return
}
}
}()
for {
select {
case err := <-errCh:
return err
default:
}
msg, err := stream.Recv()
if err == io.EOF {
return nil
}
if err != nil {
return err
}
switch p := msg.Payload.(type) {
case *agentv1.AgentMessage_Snapshot:
g.registry.UpdateContainers(agentID, p.Snapshot.Containers)
g.broker.Publish(broker.Event{
Type: "containers.updated",
AgentID: agentID,
Payload: p.Snapshot.Containers,
})
case *agentv1.AgentMessage_Result:
g.broker.Publish(broker.Event{
Type: "command.result",
AgentID: agentID,
Payload: p.Result,
})
case *agentv1.AgentMessage_LogChunk:
g.broker.Publish(broker.Event{
Type: "log.chunk",
AgentID: agentID,
Payload: p.LogChunk,
})
}
}
}
func newCommandID() string {
return uuid.NewString()
}

View File

@ -0,0 +1,105 @@
package grpc
import (
"sync"
"time"
agentv1 "github.com/containarr/server/internal/proto/agentv1"
)
type AgentState struct {
ID string
Hostname string
Alias string
IPAddress string
Arch string
OS string
LastSeenAt time.Time
Containers []*agentv1.ContainerInfo
cmdCh chan *agentv1.ServerMessage
}
type Registry struct {
mu sync.RWMutex
agents map[string]*AgentState
}
func NewRegistry() *Registry {
return &Registry{agents: make(map[string]*AgentState)}
}
func (r *Registry) Register(id, hostname, alias, ipAddress, arch, os string) *AgentState {
state := &AgentState{
ID: id,
Hostname: hostname,
Alias: alias,
IPAddress: ipAddress,
Arch: arch,
OS: os,
cmdCh: make(chan *agentv1.ServerMessage, 16),
}
r.mu.Lock()
r.agents[id] = state
r.mu.Unlock()
return state
}
func (r *Registry) Deregister(id string) {
r.mu.Lock()
if s, ok := r.agents[id]; ok {
close(s.cmdCh)
delete(r.agents, id)
}
r.mu.Unlock()
}
func (r *Registry) Get(id string) (*AgentState, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
s, ok := r.agents[id]
return s, ok
}
func (r *Registry) List() []*AgentState {
r.mu.RLock()
defer r.mu.RUnlock()
out := make([]*AgentState, 0, len(r.agents))
for _, s := range r.agents {
out = append(out, s)
}
return out
}
func (r *Registry) UpdateContainers(id string, containers []*agentv1.ContainerInfo) {
r.mu.Lock()
defer r.mu.Unlock()
if s, ok := r.agents[id]; ok {
s.Containers = containers
s.LastSeenAt = time.Now()
}
}
// UpdateAlias refreshes the alias for a live agent (called after an admin update).
func (r *Registry) UpdateAlias(id, alias string) {
r.mu.Lock()
defer r.mu.Unlock()
if s, ok := r.agents[id]; ok {
s.Alias = alias
}
}
func (r *Registry) Send(agentID string, msg *agentv1.ServerMessage) bool {
r.mu.RLock()
s, ok := r.agents[agentID]
r.mu.RUnlock()
if !ok {
return false
}
select {
case s.cmdCh <- msg:
return true
default:
return false
}
}

View File

@ -0,0 +1,155 @@
package grpc
import (
"testing"
"time"
agentv1 "github.com/containarr/server/internal/proto/agentv1"
)
func TestRegisterAndGet(t *testing.T) {
r := NewRegistry()
state := r.Register("id1", "hostname1", "alias1", "10.0.0.1", "amd64", "linux")
if state == nil {
t.Fatal("Register returned nil")
}
got, ok := r.Get("id1")
if !ok {
t.Fatal("Get returned false for registered agent")
}
if got.ID != "id1" || got.Hostname != "hostname1" || got.Alias != "alias1" {
t.Errorf("unexpected state: %+v", got)
}
}
func TestGet_NotFound(t *testing.T) {
r := NewRegistry()
_, ok := r.Get("nonexistent")
if ok {
t.Error("expected false for unknown agent")
}
}
func TestDeregister(t *testing.T) {
r := NewRegistry()
r.Register("id1", "h", "a", "ip", "arch", "os")
r.Deregister("id1")
_, ok := r.Get("id1")
if ok {
t.Error("agent should not exist after Deregister")
}
}
func TestDeregister_NotExist(t *testing.T) {
r := NewRegistry()
// must not panic
r.Deregister("ghost")
}
func TestList(t *testing.T) {
r := NewRegistry()
if len(r.List()) != 0 {
t.Error("expected empty list")
}
r.Register("a1", "h1", "", "", "", "")
r.Register("a2", "h2", "", "", "", "")
if len(r.List()) != 2 {
t.Errorf("expected 2 agents, got %d", len(r.List()))
}
}
func TestUpdateContainers(t *testing.T) {
r := NewRegistry()
r.Register("id1", "h", "a", "ip", "arch", "os")
before := time.Now()
containers := []*agentv1.ContainerInfo{
{Id: "c1", Name: "web"},
{Id: "c2", Name: "db"},
}
r.UpdateContainers("id1", containers)
got, _ := r.Get("id1")
if len(got.Containers) != 2 {
t.Errorf("expected 2 containers, got %d", len(got.Containers))
}
if got.LastSeenAt.Before(before) {
t.Error("LastSeenAt should have been updated")
}
}
func TestUpdateContainers_UnknownAgent(t *testing.T) {
r := NewRegistry()
// must not panic
r.UpdateContainers("ghost", nil)
}
func TestUpdateAlias(t *testing.T) {
r := NewRegistry()
r.Register("id1", "h", "old-alias", "ip", "arch", "os")
r.UpdateAlias("id1", "new-alias")
got, _ := r.Get("id1")
if got.Alias != "new-alias" {
t.Errorf("expected 'new-alias', got %q", got.Alias)
}
}
func TestUpdateAlias_UnknownAgent(t *testing.T) {
r := NewRegistry()
// must not panic
r.UpdateAlias("ghost", "alias")
}
func TestSend(t *testing.T) {
r := NewRegistry()
state := r.Register("id1", "h", "a", "ip", "arch", "os")
msg := &agentv1.ServerMessage{}
ok := r.Send("id1", msg)
if !ok {
t.Fatal("Send returned false for connected agent")
}
// Drain the channel to verify the message arrived.
select {
case got := <-state.cmdCh:
if got != msg {
t.Error("received wrong message")
}
case <-time.After(time.Second):
t.Fatal("timed out reading from cmdCh")
}
}
func TestSend_UnknownAgent(t *testing.T) {
r := NewRegistry()
ok := r.Send("ghost", &agentv1.ServerMessage{})
if ok {
t.Error("Send should return false for unknown agent")
}
}
func TestSend_FullChannel(t *testing.T) {
r := NewRegistry()
r.Register("id1", "h", "a", "ip", "arch", "os")
// Fill the buffer (size 16)
for i := 0; i < 16; i++ {
r.Send("id1", &agentv1.ServerMessage{})
}
// Next send on a full channel should return false
ok := r.Send("id1", &agentv1.ServerMessage{})
if ok {
t.Error("Send should return false when channel is full")
}
}

View File

@ -0,0 +1,183 @@
package store
import (
"database/sql"
"time"
_ "github.com/mattn/go-sqlite3"
)
type Agent struct {
ID string
Token string
Hostname string
Alias string
IPAddress string
Arch string
OS string
LastSeenAt time.Time
Online bool
}
type Store struct {
db *sql.DB
}
func New(path string) (*Store, error) {
db, err := sql.Open("sqlite3", path+"?_journal_mode=WAL&_foreign_keys=on")
if err != nil {
return nil, err
}
s := &Store{db: db}
return s, s.migrate()
}
func (s *Store) migrate() error {
_, err := s.db.Exec(`
CREATE TABLE IF NOT EXISTS users (
username TEXT PRIMARY KEY,
password_hash TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS agents (
id TEXT PRIMARY KEY,
token TEXT UNIQUE NOT NULL,
hostname TEXT NOT NULL,
alias TEXT NOT NULL DEFAULT '',
ip_address TEXT NOT NULL DEFAULT '',
arch TEXT NOT NULL DEFAULT '',
os TEXT NOT NULL DEFAULT '',
last_seen_at DATETIME,
online INTEGER NOT NULL DEFAULT 0
);
`)
if err != nil {
return err
}
// Idempotent — ignore error if column already exists.
for _, col := range []string{
`ALTER TABLE agents ADD COLUMN alias TEXT NOT NULL DEFAULT ''`,
`ALTER TABLE agents ADD COLUMN ip_address TEXT NOT NULL DEFAULT ''`,
} {
s.db.Exec(col)
}
return nil
}
func (s *Store) Close() error { return s.db.Close() }
func (s *Store) UpsertAgent(a *Agent) error {
_, err := s.db.Exec(`
INSERT INTO agents (id, token, hostname, alias, ip_address, arch, os, last_seen_at, online)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(token) DO UPDATE SET
hostname = excluded.hostname,
ip_address = excluded.ip_address,
arch = excluded.arch,
os = excluded.os,
last_seen_at = excluded.last_seen_at,
online = excluded.online
`, a.ID, a.Token, a.Hostname, a.Alias, a.IPAddress, a.Arch, a.OS, a.LastSeenAt, boolToInt(a.Online))
return err
}
func (s *Store) AgentByToken(token string) (*Agent, error) {
row := s.db.QueryRow(`
SELECT id, token, hostname, alias, ip_address, arch, os, last_seen_at, online
FROM agents WHERE token = ?`, token)
return scanAgent(row)
}
func (s *Store) GetAgent(id string) (*Agent, error) {
row := s.db.QueryRow(`
SELECT id, token, hostname, alias, ip_address, arch, os, last_seen_at, online
FROM agents WHERE id = ?`, id)
return scanAgent(row)
}
func (s *Store) ListAgents() ([]*Agent, error) {
rows, err := s.db.Query(`
SELECT id, token, hostname, alias, ip_address, arch, os, last_seen_at, online
FROM agents ORDER BY hostname`)
if err != nil {
return nil, err
}
defer rows.Close()
var agents []*Agent
for rows.Next() {
a := &Agent{}
var online int
var lastSeen sql.NullTime
if err := rows.Scan(&a.ID, &a.Token, &a.Hostname, &a.Alias, &a.IPAddress, &a.Arch, &a.OS, &lastSeen, &online); err != nil {
return nil, err
}
if lastSeen.Valid {
a.LastSeenAt = lastSeen.Time
}
a.Online = online == 1
agents = append(agents, a)
}
return agents, rows.Err()
}
func (s *Store) SetAgentOffline(id string) error {
_, err := s.db.Exec(`UPDATE agents SET online = 0 WHERE id = ?`, id)
return err
}
func (s *Store) CreateAgentToken(id, token, hostname string) error {
_, err := s.db.Exec(`
INSERT OR IGNORE INTO agents (id, token, hostname, arch, os, online)
VALUES (?, ?, ?, '', '', 0)
`, id, token, hostname)
return err
}
func (s *Store) UpdateAgentAlias(id, alias string) error {
_, err := s.db.Exec(`UPDATE agents SET alias = ? WHERE id = ?`, alias, id)
return err
}
// ── Users ─────────────────────────────────────────────────────────────────────
func (s *Store) GetUserHash(username string) (string, error) {
var hash string
err := s.db.QueryRow(`SELECT password_hash FROM users WHERE username = ?`, username).Scan(&hash)
return hash, err
}
func (s *Store) UpsertUser(username, hash string) error {
_, err := s.db.Exec(`
INSERT INTO users (username, password_hash) VALUES (?, ?)
ON CONFLICT(username) DO UPDATE SET password_hash = excluded.password_hash
`, username, hash)
return err
}
func (s *Store) UserExists(username string) (bool, error) {
var n int
err := s.db.QueryRow(`SELECT COUNT(*) FROM users WHERE username = ?`, username).Scan(&n)
return n > 0, err
}
func scanAgent(row *sql.Row) (*Agent, error) {
a := &Agent{}
var online int
var lastSeen sql.NullTime
err := row.Scan(&a.ID, &a.Token, &a.Hostname, &a.Alias, &a.IPAddress, &a.Arch, &a.OS, &lastSeen, &online)
if err != nil {
return nil, err
}
if lastSeen.Valid {
a.LastSeenAt = lastSeen.Time
}
a.Online = online == 1
return a, nil
}
func boolToInt(b bool) int {
if b {
return 1
}
return 0
}

View File

@ -0,0 +1,208 @@
package store
import (
"testing"
)
func newTestStore(t *testing.T) *Store {
t.Helper()
s, err := New(":memory:")
if err != nil {
t.Fatalf("failed to open in-memory store: %v", err)
}
t.Cleanup(func() { s.Close() })
return s
}
// ── Users ─────────────────────────────────────────────────────────────────────
func TestUpsertAndGetUserHash(t *testing.T) {
s := newTestStore(t)
if err := s.UpsertUser("alice", "hash123"); err != nil {
t.Fatalf("UpsertUser: %v", err)
}
h, err := s.GetUserHash("alice")
if err != nil {
t.Fatalf("GetUserHash: %v", err)
}
if h != "hash123" {
t.Errorf("expected hash123, got %q", h)
}
}
func TestGetUserHash_NotFound(t *testing.T) {
s := newTestStore(t)
_, err := s.GetUserHash("nobody")
if err == nil {
t.Fatal("expected error for missing user, got nil")
}
}
func TestUpsertUser_Update(t *testing.T) {
s := newTestStore(t)
if err := s.UpsertUser("alice", "first"); err != nil {
t.Fatalf("UpsertUser: %v", err)
}
if err := s.UpsertUser("alice", "second"); err != nil {
t.Fatalf("UpsertUser update: %v", err)
}
h, err := s.GetUserHash("alice")
if err != nil {
t.Fatalf("GetUserHash: %v", err)
}
if h != "second" {
t.Errorf("expected second, got %q", h)
}
}
func TestUserExists(t *testing.T) {
s := newTestStore(t)
ok, err := s.UserExists("alice")
if err != nil {
t.Fatalf("UserExists: %v", err)
}
if ok {
t.Error("expected false for non-existent user")
}
_ = s.UpsertUser("alice", "hash")
ok, err = s.UserExists("alice")
if err != nil {
t.Fatalf("UserExists: %v", err)
}
if !ok {
t.Error("expected true after insert")
}
}
// ── Agents ────────────────────────────────────────────────────────────────────
func TestCreateAgentToken(t *testing.T) {
s := newTestStore(t)
if err := s.CreateAgentToken("id1", "tok1", "host1"); err != nil {
t.Fatalf("CreateAgentToken: %v", err)
}
a, err := s.AgentByToken("tok1")
if err != nil {
t.Fatalf("AgentByToken: %v", err)
}
if a.ID != "id1" || a.Hostname != "host1" {
t.Errorf("unexpected agent: %+v", a)
}
}
func TestAgentByToken_NotFound(t *testing.T) {
s := newTestStore(t)
_, err := s.AgentByToken("doesnotexist")
if err == nil {
t.Fatal("expected error for unknown token")
}
}
func TestUpsertAgent(t *testing.T) {
s := newTestStore(t)
a := &Agent{
ID: "agent1",
Token: "tok1",
Hostname: "myhost",
Alias: "myalias",
IPAddress: "10.0.0.1",
Arch: "amd64",
OS: "linux",
Online: true,
}
if err := s.UpsertAgent(a); err != nil {
t.Fatalf("UpsertAgent: %v", err)
}
got, err := s.GetAgent("agent1")
if err != nil {
t.Fatalf("GetAgent: %v", err)
}
if got.Hostname != "myhost" || got.Alias != "myalias" {
t.Errorf("unexpected agent: %+v", got)
}
}
func TestListAgents(t *testing.T) {
s := newTestStore(t)
_ = s.CreateAgentToken("a1", "t1", "host-b")
_ = s.CreateAgentToken("a2", "t2", "host-a")
agents, err := s.ListAgents()
if err != nil {
t.Fatalf("ListAgents: %v", err)
}
if len(agents) != 2 {
t.Fatalf("expected 2 agents, got %d", len(agents))
}
// ORDER BY hostname: host-a < host-b
if agents[0].Hostname != "host-a" || agents[1].Hostname != "host-b" {
t.Errorf("unexpected order: %v %v", agents[0].Hostname, agents[1].Hostname)
}
}
func TestSetAgentOffline(t *testing.T) {
s := newTestStore(t)
_ = s.UpsertAgent(&Agent{
ID: "a1",
Token: "t1",
Hostname: "h1",
Online: true,
})
if err := s.SetAgentOffline("a1"); err != nil {
t.Fatalf("SetAgentOffline: %v", err)
}
a, err := s.GetAgent("a1")
if err != nil {
t.Fatalf("GetAgent: %v", err)
}
if a.Online {
t.Error("expected Online=false after SetAgentOffline")
}
}
func TestUpdateAgentAlias(t *testing.T) {
s := newTestStore(t)
_ = s.CreateAgentToken("a1", "t1", "host1")
if err := s.UpdateAgentAlias("a1", "newalias"); err != nil {
t.Fatalf("UpdateAgentAlias: %v", err)
}
a, err := s.GetAgent("a1")
if err != nil {
t.Fatalf("GetAgent: %v", err)
}
if a.Alias != "newalias" {
t.Errorf("expected alias 'newalias', got %q", a.Alias)
}
}
func TestCreateAgentToken_IdempotentIgnore(t *testing.T) {
s := newTestStore(t)
// INSERT OR IGNORE — second call should not error
if err := s.CreateAgentToken("id1", "tok1", "h1"); err != nil {
t.Fatalf("first call: %v", err)
}
if err := s.CreateAgentToken("id1", "tok1", "h1"); err != nil {
t.Fatalf("second call (should be idempotent): %v", err)
}
}