feat: add first page with auth and containers list and agents
This commit is contained in:
550
server/internal/api/api_test.go
Normal file
550
server/internal/api/api_test.go
Normal file
@ -0,0 +1,550 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/containarr/server/internal/broker"
|
||||
grpcgateway "github.com/containarr/server/internal/grpc"
|
||||
agentv1 "github.com/containarr/server/internal/proto/agentv1"
|
||||
"github.com/containarr/server/internal/store"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// ── helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
func newTestHandler(t *testing.T) (*Handler, *store.Store, *grpcgateway.Registry, *broker.Broker) {
|
||||
t.Helper()
|
||||
s, err := store.New(":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("store.New: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { s.Close() })
|
||||
|
||||
reg := grpcgateway.NewRegistry()
|
||||
b := broker.New()
|
||||
h := NewHandler(s, reg, b)
|
||||
return h, s, reg, b
|
||||
}
|
||||
|
||||
func makeJWT(t *testing.T, subject string) string {
|
||||
t.Helper()
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &jwtClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: subject,
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
})
|
||||
signed, err := token.SignedString(jwtSecret())
|
||||
if err != nil {
|
||||
t.Fatalf("makeJWT: %v", err)
|
||||
}
|
||||
return signed
|
||||
}
|
||||
|
||||
func bearerHeader(token string) string {
|
||||
return "Bearer " + token
|
||||
}
|
||||
|
||||
func postJSON(t *testing.T, handler http.HandlerFunc, path string, body any) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
b, _ := json.Marshal(body)
|
||||
req := httptest.NewRequest(http.MethodPost, path, bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
handler(w, req)
|
||||
return w
|
||||
}
|
||||
|
||||
func postJSONAuth(t *testing.T, handler http.HandlerFunc, path string, body any, token string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
b, _ := json.Marshal(body)
|
||||
req := httptest.NewRequest(http.MethodPost, path, bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerHeader(token))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Wrap handler with requireJWT so claims land in context.
|
||||
requireJWT(handler).ServeHTTP(w, req)
|
||||
return w
|
||||
}
|
||||
|
||||
// ── extractToken ──────────────────────────────────────────────────────────────
|
||||
|
||||
func TestExtractToken_BearerHeader(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer mytoken")
|
||||
if got := extractToken(req); got != "mytoken" {
|
||||
t.Errorf("expected 'mytoken', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractToken_QueryParam(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/?token=querytoken", nil)
|
||||
if got := extractToken(req); got != "querytoken" {
|
||||
t.Errorf("expected 'querytoken', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractToken_Empty(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
if got := extractToken(req); got != "" {
|
||||
t.Errorf("expected empty, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractToken_ShortAuthHeader(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Authorization", "Bear") // len < 7
|
||||
if got := extractToken(req); got != "" {
|
||||
t.Errorf("expected empty for short header, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ── requireJWT middleware ─────────────────────────────────────────────────────
|
||||
|
||||
func TestRequireJWT_MissingToken(t *testing.T) {
|
||||
called := false
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { called = true })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
requireJWT(next).ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
if called {
|
||||
t.Error("handler should not be called without token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireJWT_InvalidToken(t *testing.T) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer not.a.real.token")
|
||||
w := httptest.NewRecorder()
|
||||
requireJWT(next).ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireJWT_ValidToken(t *testing.T) {
|
||||
token := makeJWT(t, "alice")
|
||||
called := false
|
||||
var gotSubject string
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
c, ok := claimsFromContext(r)
|
||||
if !ok {
|
||||
t.Error("claims not in context")
|
||||
return
|
||||
}
|
||||
gotSubject = c.Subject
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Authorization", bearerHeader(token))
|
||||
w := httptest.NewRecorder()
|
||||
requireJWT(next).ServeHTTP(w, req)
|
||||
|
||||
if !called {
|
||||
t.Error("handler was not called")
|
||||
}
|
||||
if gotSubject != "alice" {
|
||||
t.Errorf("expected subject 'alice', got %q", gotSubject)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireJWT_WrongSecret(t *testing.T) {
|
||||
// Sign with a different secret
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &jwtClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: "hacker",
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
},
|
||||
})
|
||||
signed, _ := token.SignedString([]byte("wrong-secret"))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Authorization", bearerHeader(signed))
|
||||
w := httptest.NewRecorder()
|
||||
requireJWT(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})).ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Login ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestLogin_Success(t *testing.T) {
|
||||
h, s, _, _ := newTestHandler(t)
|
||||
|
||||
hash, _ := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.MinCost)
|
||||
_ = s.UpsertUser("alice", string(hash))
|
||||
|
||||
w := postJSON(t, h.Login, "/api/v1/auth/login", map[string]string{
|
||||
"username": "alice",
|
||||
"password": "password",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d — body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]string
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if resp["token"] == "" {
|
||||
t.Error("expected non-empty token in response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogin_WrongPassword(t *testing.T) {
|
||||
h, s, _, _ := newTestHandler(t)
|
||||
|
||||
hash, _ := bcrypt.GenerateFromPassword([]byte("correct"), bcrypt.MinCost)
|
||||
_ = s.UpsertUser("alice", string(hash))
|
||||
|
||||
w := postJSON(t, h.Login, "/api/v1/auth/login", map[string]string{
|
||||
"username": "alice",
|
||||
"password": "wrong",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogin_UnknownUser(t *testing.T) {
|
||||
h, _, _, _ := newTestHandler(t)
|
||||
|
||||
w := postJSON(t, h.Login, "/api/v1/auth/login", map[string]string{
|
||||
"username": "nobody",
|
||||
"password": "pass",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogin_BadBody(t *testing.T) {
|
||||
h, _, _, _ := newTestHandler(t)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", strings.NewReader("not-json"))
|
||||
w := httptest.NewRecorder()
|
||||
h.Login(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogin_EmptyFields(t *testing.T) {
|
||||
h, _, _, _ := newTestHandler(t)
|
||||
|
||||
w := postJSON(t, h.Login, "/api/v1/auth/login", map[string]string{
|
||||
"username": "",
|
||||
"password": "",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ── ChangePassword ─────────────────────────────────────────────────────────────
|
||||
|
||||
func TestChangePassword_Success(t *testing.T) {
|
||||
h, s, _, _ := newTestHandler(t)
|
||||
os.Setenv("JWT_SECRET", "test-secret-change-pw")
|
||||
defer os.Unsetenv("JWT_SECRET")
|
||||
|
||||
hash, _ := bcrypt.GenerateFromPassword([]byte("oldpass"), bcrypt.MinCost)
|
||||
_ = s.UpsertUser("alice", string(hash))
|
||||
|
||||
token := makeJWT(t, "alice")
|
||||
w := postJSONAuth(t, h.ChangePassword, "/api/v1/auth/change-password", map[string]string{
|
||||
"current_password": "oldpass",
|
||||
"new_password": "newpass",
|
||||
}, token)
|
||||
|
||||
if w.Code != http.StatusNoContent {
|
||||
t.Fatalf("expected 204, got %d — body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify new hash is stored
|
||||
newHash, _ := s.GetUserHash("alice")
|
||||
if bcrypt.CompareHashAndPassword([]byte(newHash), []byte("newpass")) != nil {
|
||||
t.Error("new password hash does not match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChangePassword_WrongCurrentPassword(t *testing.T) {
|
||||
h, s, _, _ := newTestHandler(t)
|
||||
|
||||
hash, _ := bcrypt.GenerateFromPassword([]byte("correct"), bcrypt.MinCost)
|
||||
_ = s.UpsertUser("alice", string(hash))
|
||||
|
||||
token := makeJWT(t, "alice")
|
||||
w := postJSONAuth(t, h.ChangePassword, "/api/v1/auth/change-password", map[string]string{
|
||||
"current_password": "wrong",
|
||||
"new_password": "newpass",
|
||||
}, token)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ── ListAgents ────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestListAgents_Empty(t *testing.T) {
|
||||
h, _, _, _ := newTestHandler(t)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.ListAgents(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var agents []agentDTO
|
||||
if err := json.NewDecoder(w.Body).Decode(&agents); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if len(agents) != 0 {
|
||||
t.Errorf("expected empty list, got %d", len(agents))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListAgents_PersistenceAndLive(t *testing.T) {
|
||||
h, s, reg, _ := newTestHandler(t)
|
||||
|
||||
_ = s.CreateAgentToken("a1", "t1", "host1")
|
||||
// Register a2 in the registry (simulating live agent)
|
||||
reg.Register("a2", "host2", "alias2", "192.168.1.1", "arm64", "linux")
|
||||
_ = s.CreateAgentToken("a2", "t2", "host2")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.ListAgents(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var agents []agentDTO
|
||||
json.NewDecoder(w.Body).Decode(&agents)
|
||||
|
||||
if len(agents) != 2 {
|
||||
t.Fatalf("expected 2 agents, got %d", len(agents))
|
||||
}
|
||||
|
||||
// Find a2 — it should be online
|
||||
var a2 *agentDTO
|
||||
for i := range agents {
|
||||
if agents[i].ID == "a2" {
|
||||
a2 = &agents[i]
|
||||
}
|
||||
}
|
||||
if a2 == nil {
|
||||
t.Fatal("a2 not found in list")
|
||||
}
|
||||
if !a2.Online {
|
||||
t.Error("a2 should be online (registered in registry)")
|
||||
}
|
||||
}
|
||||
|
||||
// ── CreateAgentToken ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestCreateAgentToken_Success(t *testing.T) {
|
||||
h, _, _, _ := newTestHandler(t)
|
||||
|
||||
b, _ := json.Marshal(map[string]string{"hostname": "new-agent"})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/token", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
h.CreateAgentToken(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d — body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]string
|
||||
json.NewDecoder(w.Body).Decode(&resp)
|
||||
|
||||
if resp["agent_id"] == "" || resp["token"] == "" {
|
||||
t.Errorf("missing agent_id or token in response: %v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateAgentToken_MissingHostname(t *testing.T) {
|
||||
h, _, _, _ := newTestHandler(t)
|
||||
|
||||
b, _ := json.Marshal(map[string]string{"hostname": ""})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/token", bytes.NewReader(b))
|
||||
w := httptest.NewRecorder()
|
||||
h.CreateAgentToken(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ── UpdateAgent ───────────────────────────────────────────────────────────────
|
||||
|
||||
func TestUpdateAgent_Success(t *testing.T) {
|
||||
h, s, reg, _ := newTestHandler(t)
|
||||
|
||||
_ = s.CreateAgentToken("a1", "t1", "host1")
|
||||
reg.Register("a1", "host1", "old", "ip", "arch", "os")
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"alias": "new-alias"})
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Patch("/api/v1/agents/{agentID}", h.UpdateAgent)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPatch, "/api/v1/agents/a1", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d — body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp agentDTO
|
||||
json.NewDecoder(w.Body).Decode(&resp)
|
||||
if resp.Alias != "new-alias" {
|
||||
t.Errorf("expected alias 'new-alias', got %q", resp.Alias)
|
||||
}
|
||||
|
||||
// Confirm registry also updated
|
||||
state, ok := reg.Get("a1")
|
||||
if !ok {
|
||||
t.Fatal("agent not in registry")
|
||||
}
|
||||
if state.Alias != "new-alias" {
|
||||
t.Errorf("registry alias not updated, got %q", state.Alias)
|
||||
}
|
||||
}
|
||||
|
||||
// ── ListContainers ─────────────────────────────────────────────────────────────
|
||||
|
||||
func TestListContainers_Empty(t *testing.T) {
|
||||
h, _, _, _ := newTestHandler(t)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/containers", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.ListContainers(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListContainers_WithData(t *testing.T) {
|
||||
h, _, reg, _ := newTestHandler(t)
|
||||
|
||||
reg.Register("a1", "host1", "alias1", "10.0.0.1", "amd64", "linux")
|
||||
reg.UpdateContainers("a1", []*agentv1.ContainerInfo{
|
||||
{Id: "c1", Name: "web"},
|
||||
{Id: "c2", Name: "db"},
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/containers", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.ListContainers(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var out []struct {
|
||||
AgentID string `json:"agent_id"`
|
||||
Container *agentv1.ContainerInfo `json:"container"`
|
||||
}
|
||||
json.NewDecoder(w.Body).Decode(&out)
|
||||
|
||||
if len(out) != 2 {
|
||||
t.Errorf("expected 2 containers, got %d", len(out))
|
||||
}
|
||||
}
|
||||
|
||||
// ── ContainerAction ───────────────────────────────────────────────────────────
|
||||
|
||||
func TestContainerAction_AgentNotConnected(t *testing.T) {
|
||||
h, _, _, _ := newTestHandler(t)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"action": "start"})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/ghost/containers/c1/action", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Post("/api/v1/agents/{agentID}/containers/{containerID}/action", h.ContainerAction)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("expected 503, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainerAction_InvalidAction(t *testing.T) {
|
||||
h, _, reg, _ := newTestHandler(t)
|
||||
reg.Register("a1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"action": "explode"})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a1/containers/c1/action", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Post("/api/v1/agents/{agentID}/containers/{containerID}/action", h.ContainerAction)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainerAction_Success(t *testing.T) {
|
||||
h, _, reg, _ := newTestHandler(t)
|
||||
reg.Register("a1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"action": "stop"})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a1/containers/c1/action", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Post("/api/v1/agents/{agentID}/containers/{containerID}/action", h.ContainerAction)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d — body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]string
|
||||
json.NewDecoder(w.Body).Decode(&resp)
|
||||
if resp["command_id"] == "" {
|
||||
t.Error("expected command_id in response")
|
||||
}
|
||||
}
|
||||
125
server/internal/api/auth.go
Normal file
125
server/internal/api/auth.go
Normal file
@ -0,0 +1,125 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func jwtSecret() []byte {
|
||||
if s := os.Getenv("JWT_SECRET"); s != "" {
|
||||
return []byte(s)
|
||||
}
|
||||
return []byte("dev-secret-change-me")
|
||||
}
|
||||
|
||||
type jwtClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
var body struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil || body.Username == "" || body.Password == "" {
|
||||
http.Error(w, "invalid body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
hash, err := h.store.GetUserHash(body.Username)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid credentials", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if bcrypt.CompareHashAndPassword([]byte(hash), []byte(body.Password)) != nil {
|
||||
http.Error(w, "invalid credentials", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &jwtClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: body.Username,
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
})
|
||||
signed, err := token.SignedString(jwtSecret())
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
jsonOK(w, map[string]string{"token": signed})
|
||||
}
|
||||
|
||||
func (h *Handler) ChangePassword(w http.ResponseWriter, r *http.Request) {
|
||||
claims, ok := claimsFromContext(r)
|
||||
if !ok {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
var body struct {
|
||||
CurrentPassword string `json:"current_password"`
|
||||
NewPassword string `json:"new_password"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil || body.NewPassword == "" {
|
||||
http.Error(w, "invalid body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
hash, err := h.store.GetUserHash(claims.Subject)
|
||||
if err != nil {
|
||||
http.Error(w, "user not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if bcrypt.CompareHashAndPassword([]byte(hash), []byte(body.CurrentPassword)) != nil {
|
||||
http.Error(w, "mot de passe actuel incorrect", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
newHash, err := bcrypt.GenerateFromPassword([]byte(body.NewPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err := h.store.UpsertUser(claims.Subject, string(newHash)); err != nil {
|
||||
http.Error(w, "store error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func requireJWT(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw := extractToken(r)
|
||||
if raw == "" {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
t, err := jwt.ParseWithClaims(raw, &jwtClaims{}, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, jwt.ErrSignatureInvalid
|
||||
}
|
||||
return jwtSecret(), nil
|
||||
})
|
||||
if err != nil || !t.Valid {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(
|
||||
contextWithClaims(r.Context(), t.Claims.(*jwtClaims)),
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
func extractToken(r *http.Request) string {
|
||||
if auth := r.Header.Get("Authorization"); len(auth) > 7 && auth[:7] == "Bearer " {
|
||||
return auth[7:]
|
||||
}
|
||||
return r.URL.Query().Get("token")
|
||||
}
|
||||
16
server/internal/api/context.go
Normal file
16
server/internal/api/context.go
Normal file
@ -0,0 +1,16 @@
|
||||
package api
|
||||
|
||||
import "context"
|
||||
|
||||
type contextKey int
|
||||
|
||||
const claimsKey contextKey = iota
|
||||
|
||||
func contextWithClaims(ctx context.Context, c *jwtClaims) context.Context {
|
||||
return context.WithValue(ctx, claimsKey, c)
|
||||
}
|
||||
|
||||
func claimsFromContext(r interface{ Context() context.Context }) (*jwtClaims, bool) {
|
||||
c, ok := r.Context().Value(claimsKey).(*jwtClaims)
|
||||
return c, ok && c != nil
|
||||
}
|
||||
301
server/internal/api/handlers.go
Normal file
301
server/internal/api/handlers.go
Normal file
@ -0,0 +1,301 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/containarr/server/internal/broker"
|
||||
grpcgateway "github.com/containarr/server/internal/grpc"
|
||||
agentv1 "github.com/containarr/server/internal/proto/agentv1"
|
||||
"github.com/containarr/server/internal/store"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
|
||||
type Handler struct {
|
||||
store *store.Store
|
||||
registry *grpcgateway.Registry
|
||||
broker *broker.Broker
|
||||
}
|
||||
|
||||
func NewHandler(s *store.Store, r *grpcgateway.Registry, b *broker.Broker) *Handler {
|
||||
return &Handler{store: s, registry: r, broker: b}
|
||||
}
|
||||
|
||||
type agentDTO struct {
|
||||
ID string `json:"id"`
|
||||
Hostname string `json:"hostname"`
|
||||
Alias string `json:"alias"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
Arch string `json:"arch"`
|
||||
OS string `json:"os"`
|
||||
Online bool `json:"online"`
|
||||
LastSeenAt time.Time `json:"last_seen_at"`
|
||||
}
|
||||
|
||||
// ── Agents ────────────────────────────────────────────────────────────────────
|
||||
|
||||
func (h *Handler) ListAgents(w http.ResponseWriter, r *http.Request) {
|
||||
persisted, err := h.store.ListAgents()
|
||||
if err != nil {
|
||||
http.Error(w, "store error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
liveByID := map[string]*grpcgateway.AgentState{}
|
||||
for _, s := range h.registry.List() {
|
||||
liveByID[s.ID] = s
|
||||
}
|
||||
out := make([]agentDTO, 0, len(persisted))
|
||||
for _, a := range persisted {
|
||||
dto := agentDTO{
|
||||
ID: a.ID,
|
||||
Hostname: a.Hostname,
|
||||
Alias: a.Alias,
|
||||
IPAddress: a.IPAddress,
|
||||
Arch: a.Arch,
|
||||
OS: a.OS,
|
||||
}
|
||||
if live, ok := liveByID[a.ID]; ok {
|
||||
dto.Online = true
|
||||
dto.IPAddress = live.IPAddress
|
||||
dto.LastSeenAt = live.LastSeenAt
|
||||
}
|
||||
out = append(out, dto)
|
||||
}
|
||||
jsonOK(w, out)
|
||||
}
|
||||
|
||||
func (h *Handler) UpdateAgent(w http.ResponseWriter, r *http.Request) {
|
||||
agentID := chi.URLParam(r, "agentID")
|
||||
var body struct {
|
||||
Alias string `json:"alias"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
http.Error(w, "invalid body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err := h.store.UpdateAgentAlias(agentID, body.Alias); err != nil {
|
||||
http.Error(w, "store error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
h.registry.UpdateAlias(agentID, body.Alias)
|
||||
|
||||
a, err := h.store.GetAgent(agentID)
|
||||
if err != nil {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
jsonOK(w, agentDTO{
|
||||
ID: a.ID,
|
||||
Hostname: a.Hostname,
|
||||
Alias: a.Alias,
|
||||
IPAddress: a.IPAddress,
|
||||
Arch: a.Arch,
|
||||
OS: a.OS,
|
||||
Online: a.Online,
|
||||
})
|
||||
}
|
||||
|
||||
// ── Containers ────────────────────────────────────────────────────────────────
|
||||
|
||||
func (h *Handler) ListContainers(w http.ResponseWriter, r *http.Request) {
|
||||
type containerDTO struct {
|
||||
AgentID string `json:"agent_id"`
|
||||
Hostname string `json:"hostname"`
|
||||
Alias string `json:"alias"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
Container *agentv1.ContainerInfo `json:"container"`
|
||||
}
|
||||
var out []containerDTO
|
||||
for _, agent := range h.registry.List() {
|
||||
for _, c := range agent.Containers {
|
||||
out = append(out, containerDTO{
|
||||
AgentID: agent.ID,
|
||||
Hostname: agent.Hostname,
|
||||
Alias: agent.Alias,
|
||||
IPAddress: agent.IPAddress,
|
||||
Container: c,
|
||||
})
|
||||
}
|
||||
}
|
||||
jsonOK(w, out)
|
||||
}
|
||||
|
||||
func (h *Handler) ContainerAction(w http.ResponseWriter, r *http.Request) {
|
||||
agentID := chi.URLParam(r, "agentID")
|
||||
containerID := chi.URLParam(r, "containerID")
|
||||
|
||||
var body struct {
|
||||
Action string `json:"action"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
http.Error(w, "invalid body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
action, ok := map[string]agentv1.ContainerAction{
|
||||
"start": agentv1.ContainerAction_CONTAINER_ACTION_START,
|
||||
"stop": agentv1.ContainerAction_CONTAINER_ACTION_STOP,
|
||||
"restart": agentv1.ContainerAction_CONTAINER_ACTION_RESTART,
|
||||
"remove": agentv1.ContainerAction_CONTAINER_ACTION_REMOVE,
|
||||
}[body.Action]
|
||||
if !ok {
|
||||
http.Error(w, "unknown action", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
cmdID := uuid.NewString()
|
||||
sent := h.registry.Send(agentID, &agentv1.ServerMessage{
|
||||
Payload: &agentv1.ServerMessage_ContainerCmd{
|
||||
ContainerCmd: &agentv1.ContainerCommand{
|
||||
CommandId: cmdID,
|
||||
ContainerId: containerID,
|
||||
Action: action,
|
||||
},
|
||||
},
|
||||
})
|
||||
if !sent {
|
||||
http.Error(w, "agent not connected", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
jsonOK(w, map[string]string{"command_id": cmdID})
|
||||
}
|
||||
|
||||
// ── Agent token provisioning ──────────────────────────────────────────────────
|
||||
|
||||
func (h *Handler) CreateAgentToken(w http.ResponseWriter, r *http.Request) {
|
||||
var body struct {
|
||||
Hostname string `json:"hostname"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil || body.Hostname == "" {
|
||||
http.Error(w, "hostname required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
id := uuid.NewString()
|
||||
token := uuid.NewString()
|
||||
if err := h.store.CreateAgentToken(id, token, body.Hostname); err != nil {
|
||||
http.Error(w, "store error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
jsonOK(w, map[string]string{"agent_id": id, "token": token})
|
||||
}
|
||||
|
||||
// ── Container log stream ──────────────────────────────────────────────────────
|
||||
|
||||
func (h *Handler) LogsWS(w http.ResponseWriter, r *http.Request) {
|
||||
agentID := chi.URLParam(r, "agentID")
|
||||
containerID := chi.URLParam(r, "containerID")
|
||||
|
||||
follow := r.URL.Query().Get("follow") != "false"
|
||||
tail := int32(100)
|
||||
if t := r.URL.Query().Get("tail"); t != "" {
|
||||
if n, err := strconv.Atoi(t); err == nil && n > 0 {
|
||||
tail = int32(n)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
sent := h.registry.Send(agentID, &agentv1.ServerMessage{
|
||||
Payload: &agentv1.ServerMessage_StreamLogs{
|
||||
StreamLogs: &agentv1.StreamLogsCommand{
|
||||
CommandId: uuid.NewString(),
|
||||
ContainerId: containerID,
|
||||
Follow: follow,
|
||||
Tail: tail,
|
||||
},
|
||||
},
|
||||
})
|
||||
if !sent {
|
||||
conn.WriteMessage(websocket.TextMessage, []byte(`{"error":"agent not connected"}`))
|
||||
return
|
||||
}
|
||||
|
||||
sub := h.broker.Subscribe()
|
||||
defer h.broker.Unsubscribe(sub)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
if _, _, err := conn.ReadMessage(); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case raw, ok := <-sub:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
AgentID string `json:"agent_id"`
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
}
|
||||
if json.Unmarshal(raw, &envelope) != nil {
|
||||
continue
|
||||
}
|
||||
if envelope.Type != "log.chunk" || envelope.AgentID != agentID {
|
||||
continue
|
||||
}
|
||||
var chunk struct {
|
||||
ContainerID string `json:"container_id"`
|
||||
Stream string `json:"stream"`
|
||||
Data []byte `json:"data"`
|
||||
}
|
||||
if json.Unmarshal(envelope.Payload, &chunk) != nil {
|
||||
continue
|
||||
}
|
||||
if chunk.ContainerID != containerID {
|
||||
continue
|
||||
}
|
||||
msg, _ := json.Marshal(map[string]string{
|
||||
"stream": chunk.Stream,
|
||||
"line": string(chunk.Data),
|
||||
})
|
||||
if err := conn.WriteMessage(websocket.TextMessage, msg); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── WebSocket event stream ────────────────────────────────────────────────────
|
||||
|
||||
func (h *Handler) EventsWS(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
sub := h.broker.Subscribe()
|
||||
defer h.broker.Unsubscribe(sub)
|
||||
|
||||
for data := range sub {
|
||||
if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func jsonOK(w http.ResponseWriter, v any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(v)
|
||||
}
|
||||
35
server/internal/api/router.go
Normal file
35
server/internal/api/router.go
Normal file
@ -0,0 +1,35 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
)
|
||||
|
||||
func NewRouter(h *Handler) http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.Logger)
|
||||
r.Use(middleware.Recoverer)
|
||||
r.Use(middleware.RealIP)
|
||||
|
||||
r.Route("/api/v1", func(r chi.Router) {
|
||||
r.Post("/auth/login", h.Login)
|
||||
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireJWT)
|
||||
r.Post("/auth/change-password", h.ChangePassword)
|
||||
r.Get("/agents", h.ListAgents)
|
||||
r.Post("/agents/token", h.CreateAgentToken)
|
||||
r.Patch("/agents/{agentID}", h.UpdateAgent)
|
||||
r.Get("/containers", h.ListContainers)
|
||||
r.Post("/agents/{agentID}/containers/{containerID}/action", h.ContainerAction)
|
||||
r.Get("/agents/{agentID}/containers/{containerID}/logs", h.LogsWS)
|
||||
r.Get("/events", h.EventsWS)
|
||||
})
|
||||
})
|
||||
|
||||
r.Handle("/*", http.FileServer(http.Dir("./web/dist")))
|
||||
|
||||
return r
|
||||
}
|
||||
49
server/internal/auth/auth.go
Normal file
49
server/internal/auth/auth.go
Normal file
@ -0,0 +1,49 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
type Claims struct {
|
||||
UserID string `json:"uid"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
secret []byte
|
||||
}
|
||||
|
||||
func New(secret string) *Service {
|
||||
return &Service{secret: []byte(secret)}
|
||||
}
|
||||
|
||||
func (s *Service) Sign(userID string) (string, error) {
|
||||
claims := Claims{
|
||||
UserID: userID,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
}
|
||||
return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(s.secret)
|
||||
}
|
||||
|
||||
func (s *Service) Verify(tokenStr string) (*Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, errors.New("unexpected signing method")
|
||||
}
|
||||
return s.secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claims, ok := token.Claims.(*Claims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
64
server/internal/auth/auth_test.go
Normal file
64
server/internal/auth/auth_test.go
Normal file
@ -0,0 +1,64 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSignAndVerify(t *testing.T) {
|
||||
svc := New("test-secret")
|
||||
|
||||
token, err := svc.Sign("user42")
|
||||
if err != nil {
|
||||
t.Fatalf("Sign: %v", err)
|
||||
}
|
||||
if token == "" {
|
||||
t.Fatal("expected non-empty token")
|
||||
}
|
||||
|
||||
claims, err := svc.Verify(token)
|
||||
if err != nil {
|
||||
t.Fatalf("Verify: %v", err)
|
||||
}
|
||||
if claims.UserID != "user42" {
|
||||
t.Errorf("expected UserID 'user42', got %q", claims.UserID)
|
||||
}
|
||||
if claims.ExpiresAt == nil || claims.ExpiresAt.Before(time.Now()) {
|
||||
t.Error("token should not be expired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerify_InvalidToken(t *testing.T) {
|
||||
svc := New("test-secret")
|
||||
_, err := svc.Verify("not.a.valid.token")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerify_WrongSecret(t *testing.T) {
|
||||
svc1 := New("secret-a")
|
||||
svc2 := New("secret-b")
|
||||
|
||||
token, err := svc1.Sign("user1")
|
||||
if err != nil {
|
||||
t.Fatalf("Sign: %v", err)
|
||||
}
|
||||
|
||||
_, err = svc2.Verify(token)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when verifying with different secret")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerify_TamperedToken(t *testing.T) {
|
||||
svc := New("test-secret")
|
||||
token, _ := svc.Sign("admin")
|
||||
|
||||
// Append garbage to corrupt the signature.
|
||||
tampered := token + "x"
|
||||
_, err := svc.Verify(tampered)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for tampered token")
|
||||
}
|
||||
}
|
||||
55
server/internal/broker/broker.go
Normal file
55
server/internal/broker/broker.go
Normal file
@ -0,0 +1,55 @@
|
||||
package broker
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Event is a JSON-serialisable message pushed to WebSocket clients.
|
||||
type Event struct {
|
||||
Type string `json:"type"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
Payload any `json:"payload"`
|
||||
}
|
||||
|
||||
type subscriber chan []byte
|
||||
|
||||
// Broker fan-outs events to all registered WebSocket subscribers.
|
||||
type Broker struct {
|
||||
mu sync.RWMutex
|
||||
subs map[subscriber]struct{}
|
||||
}
|
||||
|
||||
func New() *Broker {
|
||||
return &Broker{subs: make(map[subscriber]struct{})}
|
||||
}
|
||||
|
||||
func (b *Broker) Subscribe() subscriber {
|
||||
ch := make(subscriber, 32)
|
||||
b.mu.Lock()
|
||||
b.subs[ch] = struct{}{}
|
||||
b.mu.Unlock()
|
||||
return ch
|
||||
}
|
||||
|
||||
func (b *Broker) Unsubscribe(ch subscriber) {
|
||||
b.mu.Lock()
|
||||
delete(b.subs, ch)
|
||||
b.mu.Unlock()
|
||||
close(ch)
|
||||
}
|
||||
|
||||
func (b *Broker) Publish(evt Event) {
|
||||
data, err := json.Marshal(evt)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
for ch := range b.subs {
|
||||
select {
|
||||
case ch <- data:
|
||||
default: // drop if subscriber is slow
|
||||
}
|
||||
}
|
||||
}
|
||||
123
server/internal/broker/broker_test.go
Normal file
123
server/internal/broker/broker_test.go
Normal file
@ -0,0 +1,123 @@
|
||||
package broker
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSubscribePublishUnsubscribe(t *testing.T) {
|
||||
b := New()
|
||||
|
||||
sub := b.Subscribe()
|
||||
|
||||
evt := Event{Type: "test.event", AgentID: "agent1", Payload: map[string]string{"k": "v"}}
|
||||
b.Publish(evt)
|
||||
|
||||
select {
|
||||
case raw := <-sub:
|
||||
var got Event
|
||||
if err := json.Unmarshal(raw, &got); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if got.Type != "test.event" || got.AgentID != "agent1" {
|
||||
t.Errorf("unexpected event: %+v", got)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for event")
|
||||
}
|
||||
|
||||
b.Unsubscribe(sub)
|
||||
|
||||
// channel must be closed after unsubscribe
|
||||
select {
|
||||
case _, ok := <-sub:
|
||||
if ok {
|
||||
t.Error("expected channel to be closed")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for channel close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleSubscribers(t *testing.T) {
|
||||
b := New()
|
||||
|
||||
sub1 := b.Subscribe()
|
||||
sub2 := b.Subscribe()
|
||||
defer b.Unsubscribe(sub1)
|
||||
defer b.Unsubscribe(sub2)
|
||||
|
||||
b.Publish(Event{Type: "ping", Payload: nil})
|
||||
|
||||
for i, sub := range []subscriber{sub1, sub2} {
|
||||
select {
|
||||
case <-sub:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("subscriber %d did not receive event", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishDropsWhenSubscriberSlow(t *testing.T) {
|
||||
b := New()
|
||||
|
||||
// Channel size is 32; fill it up and then publish one more — it must not block.
|
||||
sub := b.Subscribe()
|
||||
defer b.Unsubscribe(sub)
|
||||
|
||||
// Fill the buffer
|
||||
for i := 0; i < 32; i++ {
|
||||
b.Publish(Event{Type: "flood", Payload: i})
|
||||
}
|
||||
|
||||
// This extra publish must return immediately (dropped, not block).
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
b.Publish(Event{Type: "dropped", Payload: nil})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Publish blocked on slow subscriber")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishNoSubscribers(t *testing.T) {
|
||||
b := New()
|
||||
// Should not panic or block
|
||||
b.Publish(Event{Type: "nobody", Payload: nil})
|
||||
}
|
||||
|
||||
func TestPublishInvalidPayload(t *testing.T) {
|
||||
b := New()
|
||||
sub := b.Subscribe()
|
||||
defer b.Unsubscribe(sub)
|
||||
|
||||
// json.Marshal of a channel fails — Publish must not send anything.
|
||||
b.Publish(Event{Type: "bad", Payload: make(chan int)})
|
||||
|
||||
select {
|
||||
case <-sub:
|
||||
t.Error("should not have received a message for an unmarshalable event")
|
||||
default:
|
||||
// correct: nothing sent
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnsubscribeRemovesFromBroker(t *testing.T) {
|
||||
b := New()
|
||||
sub := b.Subscribe()
|
||||
b.Unsubscribe(sub)
|
||||
|
||||
// After unsubscribe the broker's map should be empty.
|
||||
b.mu.RLock()
|
||||
n := len(b.subs)
|
||||
b.mu.RUnlock()
|
||||
|
||||
if n != 0 {
|
||||
t.Errorf("expected 0 subscribers after unsubscribe, got %d", n)
|
||||
}
|
||||
}
|
||||
137
server/internal/grpc/gateway.go
Normal file
137
server/internal/grpc/gateway.go
Normal file
@ -0,0 +1,137 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
|
||||
"github.com/containarr/server/internal/broker"
|
||||
agentv1 "github.com/containarr/server/internal/proto/agentv1"
|
||||
"github.com/containarr/server/internal/store"
|
||||
"github.com/google/uuid"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type Gateway struct {
|
||||
agentv1.UnimplementedAgentGatewayServer
|
||||
store *store.Store
|
||||
registry *Registry
|
||||
broker *broker.Broker
|
||||
}
|
||||
|
||||
func NewGateway(s *store.Store, r *Registry, b *broker.Broker) *Gateway {
|
||||
return &Gateway{store: s, registry: r, broker: b}
|
||||
}
|
||||
|
||||
func (g *Gateway) Tunnel(stream agentv1.AgentGateway_TunnelServer) error {
|
||||
if err := stream.SendHeader(metadata.MD{}); err != nil {
|
||||
return status.Errorf(codes.Internal, "send header: %v", err)
|
||||
}
|
||||
|
||||
first, err := stream.Recv()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hs := first.GetHandshake()
|
||||
if hs == nil {
|
||||
return status.Error(codes.InvalidArgument, "first message must be AgentHandshake")
|
||||
}
|
||||
|
||||
existing, err := g.store.AgentByToken(hs.Token)
|
||||
if err != nil {
|
||||
return status.Error(codes.Unauthenticated, "unknown agent token")
|
||||
}
|
||||
|
||||
// Extract peer IP from the gRPC connection.
|
||||
ipAddress := ""
|
||||
if p, ok := peer.FromContext(stream.Context()); ok {
|
||||
if host, _, err := net.SplitHostPort(p.Addr.String()); err == nil {
|
||||
ipAddress = host
|
||||
}
|
||||
}
|
||||
|
||||
agentID := existing.ID
|
||||
slog.Info("agent connected", "id", agentID, "hostname", hs.Hostname, "ip", ipAddress)
|
||||
|
||||
state := g.registry.Register(agentID, hs.Hostname, existing.Alias, ipAddress, hs.Arch, hs.Os)
|
||||
_ = g.store.UpsertAgent(&store.Agent{
|
||||
ID: agentID,
|
||||
Token: hs.Token,
|
||||
Hostname: hs.Hostname,
|
||||
Alias: existing.Alias,
|
||||
IPAddress: ipAddress,
|
||||
Arch: hs.Arch,
|
||||
OS: hs.Os,
|
||||
Online: true,
|
||||
})
|
||||
|
||||
g.broker.Publish(broker.Event{
|
||||
Type: "agent.connected",
|
||||
AgentID: agentID,
|
||||
Payload: map[string]string{"hostname": hs.Hostname},
|
||||
})
|
||||
|
||||
defer func() {
|
||||
g.registry.Deregister(agentID)
|
||||
_ = g.store.SetAgentOffline(agentID)
|
||||
g.broker.Publish(broker.Event{Type: "agent.disconnected", AgentID: agentID, Payload: nil})
|
||||
slog.Info("agent disconnected", "id", agentID)
|
||||
}()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
for msg := range state.cmdCh {
|
||||
if err := stream.Send(msg); err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case err := <-errCh:
|
||||
return err
|
||||
default:
|
||||
}
|
||||
|
||||
msg, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch p := msg.Payload.(type) {
|
||||
case *agentv1.AgentMessage_Snapshot:
|
||||
g.registry.UpdateContainers(agentID, p.Snapshot.Containers)
|
||||
g.broker.Publish(broker.Event{
|
||||
Type: "containers.updated",
|
||||
AgentID: agentID,
|
||||
Payload: p.Snapshot.Containers,
|
||||
})
|
||||
|
||||
case *agentv1.AgentMessage_Result:
|
||||
g.broker.Publish(broker.Event{
|
||||
Type: "command.result",
|
||||
AgentID: agentID,
|
||||
Payload: p.Result,
|
||||
})
|
||||
|
||||
case *agentv1.AgentMessage_LogChunk:
|
||||
g.broker.Publish(broker.Event{
|
||||
Type: "log.chunk",
|
||||
AgentID: agentID,
|
||||
Payload: p.LogChunk,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newCommandID() string {
|
||||
return uuid.NewString()
|
||||
}
|
||||
105
server/internal/grpc/registry.go
Normal file
105
server/internal/grpc/registry.go
Normal file
@ -0,0 +1,105 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
agentv1 "github.com/containarr/server/internal/proto/agentv1"
|
||||
)
|
||||
|
||||
type AgentState struct {
|
||||
ID string
|
||||
Hostname string
|
||||
Alias string
|
||||
IPAddress string
|
||||
Arch string
|
||||
OS string
|
||||
LastSeenAt time.Time
|
||||
Containers []*agentv1.ContainerInfo
|
||||
|
||||
cmdCh chan *agentv1.ServerMessage
|
||||
}
|
||||
|
||||
type Registry struct {
|
||||
mu sync.RWMutex
|
||||
agents map[string]*AgentState
|
||||
}
|
||||
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{agents: make(map[string]*AgentState)}
|
||||
}
|
||||
|
||||
func (r *Registry) Register(id, hostname, alias, ipAddress, arch, os string) *AgentState {
|
||||
state := &AgentState{
|
||||
ID: id,
|
||||
Hostname: hostname,
|
||||
Alias: alias,
|
||||
IPAddress: ipAddress,
|
||||
Arch: arch,
|
||||
OS: os,
|
||||
cmdCh: make(chan *agentv1.ServerMessage, 16),
|
||||
}
|
||||
r.mu.Lock()
|
||||
r.agents[id] = state
|
||||
r.mu.Unlock()
|
||||
return state
|
||||
}
|
||||
|
||||
func (r *Registry) Deregister(id string) {
|
||||
r.mu.Lock()
|
||||
if s, ok := r.agents[id]; ok {
|
||||
close(s.cmdCh)
|
||||
delete(r.agents, id)
|
||||
}
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
func (r *Registry) Get(id string) (*AgentState, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
s, ok := r.agents[id]
|
||||
return s, ok
|
||||
}
|
||||
|
||||
func (r *Registry) List() []*AgentState {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
out := make([]*AgentState, 0, len(r.agents))
|
||||
for _, s := range r.agents {
|
||||
out = append(out, s)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (r *Registry) UpdateContainers(id string, containers []*agentv1.ContainerInfo) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if s, ok := r.agents[id]; ok {
|
||||
s.Containers = containers
|
||||
s.LastSeenAt = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateAlias refreshes the alias for a live agent (called after an admin update).
|
||||
func (r *Registry) UpdateAlias(id, alias string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if s, ok := r.agents[id]; ok {
|
||||
s.Alias = alias
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Registry) Send(agentID string, msg *agentv1.ServerMessage) bool {
|
||||
r.mu.RLock()
|
||||
s, ok := r.agents[agentID]
|
||||
r.mu.RUnlock()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case s.cmdCh <- msg:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
155
server/internal/grpc/registry_test.go
Normal file
155
server/internal/grpc/registry_test.go
Normal file
@ -0,0 +1,155 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
agentv1 "github.com/containarr/server/internal/proto/agentv1"
|
||||
)
|
||||
|
||||
func TestRegisterAndGet(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
state := r.Register("id1", "hostname1", "alias1", "10.0.0.1", "amd64", "linux")
|
||||
if state == nil {
|
||||
t.Fatal("Register returned nil")
|
||||
}
|
||||
|
||||
got, ok := r.Get("id1")
|
||||
if !ok {
|
||||
t.Fatal("Get returned false for registered agent")
|
||||
}
|
||||
if got.ID != "id1" || got.Hostname != "hostname1" || got.Alias != "alias1" {
|
||||
t.Errorf("unexpected state: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGet_NotFound(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
_, ok := r.Get("nonexistent")
|
||||
if ok {
|
||||
t.Error("expected false for unknown agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeregister(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("id1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
r.Deregister("id1")
|
||||
|
||||
_, ok := r.Get("id1")
|
||||
if ok {
|
||||
t.Error("agent should not exist after Deregister")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeregister_NotExist(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
// must not panic
|
||||
r.Deregister("ghost")
|
||||
}
|
||||
|
||||
func TestList(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
if len(r.List()) != 0 {
|
||||
t.Error("expected empty list")
|
||||
}
|
||||
|
||||
r.Register("a1", "h1", "", "", "", "")
|
||||
r.Register("a2", "h2", "", "", "", "")
|
||||
|
||||
if len(r.List()) != 2 {
|
||||
t.Errorf("expected 2 agents, got %d", len(r.List()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateContainers(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("id1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
before := time.Now()
|
||||
containers := []*agentv1.ContainerInfo{
|
||||
{Id: "c1", Name: "web"},
|
||||
{Id: "c2", Name: "db"},
|
||||
}
|
||||
r.UpdateContainers("id1", containers)
|
||||
|
||||
got, _ := r.Get("id1")
|
||||
if len(got.Containers) != 2 {
|
||||
t.Errorf("expected 2 containers, got %d", len(got.Containers))
|
||||
}
|
||||
if got.LastSeenAt.Before(before) {
|
||||
t.Error("LastSeenAt should have been updated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateContainers_UnknownAgent(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
// must not panic
|
||||
r.UpdateContainers("ghost", nil)
|
||||
}
|
||||
|
||||
func TestUpdateAlias(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("id1", "h", "old-alias", "ip", "arch", "os")
|
||||
|
||||
r.UpdateAlias("id1", "new-alias")
|
||||
|
||||
got, _ := r.Get("id1")
|
||||
if got.Alias != "new-alias" {
|
||||
t.Errorf("expected 'new-alias', got %q", got.Alias)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAlias_UnknownAgent(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
// must not panic
|
||||
r.UpdateAlias("ghost", "alias")
|
||||
}
|
||||
|
||||
func TestSend(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
state := r.Register("id1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
msg := &agentv1.ServerMessage{}
|
||||
ok := r.Send("id1", msg)
|
||||
if !ok {
|
||||
t.Fatal("Send returned false for connected agent")
|
||||
}
|
||||
|
||||
// Drain the channel to verify the message arrived.
|
||||
select {
|
||||
case got := <-state.cmdCh:
|
||||
if got != msg {
|
||||
t.Error("received wrong message")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out reading from cmdCh")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_UnknownAgent(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
ok := r.Send("ghost", &agentv1.ServerMessage{})
|
||||
if ok {
|
||||
t.Error("Send should return false for unknown agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_FullChannel(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("id1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
// Fill the buffer (size 16)
|
||||
for i := 0; i < 16; i++ {
|
||||
r.Send("id1", &agentv1.ServerMessage{})
|
||||
}
|
||||
|
||||
// Next send on a full channel should return false
|
||||
ok := r.Send("id1", &agentv1.ServerMessage{})
|
||||
if ok {
|
||||
t.Error("Send should return false when channel is full")
|
||||
}
|
||||
}
|
||||
183
server/internal/store/store.go
Normal file
183
server/internal/store/store.go
Normal file
@ -0,0 +1,183 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
type Agent struct {
|
||||
ID string
|
||||
Token string
|
||||
Hostname string
|
||||
Alias string
|
||||
IPAddress string
|
||||
Arch string
|
||||
OS string
|
||||
LastSeenAt time.Time
|
||||
Online bool
|
||||
}
|
||||
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func New(path string) (*Store, error) {
|
||||
db, err := sql.Open("sqlite3", path+"?_journal_mode=WAL&_foreign_keys=on")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := &Store{db: db}
|
||||
return s, s.migrate()
|
||||
}
|
||||
|
||||
func (s *Store) migrate() error {
|
||||
_, err := s.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
username TEXT PRIMARY KEY,
|
||||
password_hash TEXT NOT NULL
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS agents (
|
||||
id TEXT PRIMARY KEY,
|
||||
token TEXT UNIQUE NOT NULL,
|
||||
hostname TEXT NOT NULL,
|
||||
alias TEXT NOT NULL DEFAULT '',
|
||||
ip_address TEXT NOT NULL DEFAULT '',
|
||||
arch TEXT NOT NULL DEFAULT '',
|
||||
os TEXT NOT NULL DEFAULT '',
|
||||
last_seen_at DATETIME,
|
||||
online INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Idempotent — ignore error if column already exists.
|
||||
for _, col := range []string{
|
||||
`ALTER TABLE agents ADD COLUMN alias TEXT NOT NULL DEFAULT ''`,
|
||||
`ALTER TABLE agents ADD COLUMN ip_address TEXT NOT NULL DEFAULT ''`,
|
||||
} {
|
||||
s.db.Exec(col)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) Close() error { return s.db.Close() }
|
||||
|
||||
func (s *Store) UpsertAgent(a *Agent) error {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO agents (id, token, hostname, alias, ip_address, arch, os, last_seen_at, online)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(token) DO UPDATE SET
|
||||
hostname = excluded.hostname,
|
||||
ip_address = excluded.ip_address,
|
||||
arch = excluded.arch,
|
||||
os = excluded.os,
|
||||
last_seen_at = excluded.last_seen_at,
|
||||
online = excluded.online
|
||||
`, a.ID, a.Token, a.Hostname, a.Alias, a.IPAddress, a.Arch, a.OS, a.LastSeenAt, boolToInt(a.Online))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) AgentByToken(token string) (*Agent, error) {
|
||||
row := s.db.QueryRow(`
|
||||
SELECT id, token, hostname, alias, ip_address, arch, os, last_seen_at, online
|
||||
FROM agents WHERE token = ?`, token)
|
||||
return scanAgent(row)
|
||||
}
|
||||
|
||||
func (s *Store) GetAgent(id string) (*Agent, error) {
|
||||
row := s.db.QueryRow(`
|
||||
SELECT id, token, hostname, alias, ip_address, arch, os, last_seen_at, online
|
||||
FROM agents WHERE id = ?`, id)
|
||||
return scanAgent(row)
|
||||
}
|
||||
|
||||
func (s *Store) ListAgents() ([]*Agent, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, token, hostname, alias, ip_address, arch, os, last_seen_at, online
|
||||
FROM agents ORDER BY hostname`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var agents []*Agent
|
||||
for rows.Next() {
|
||||
a := &Agent{}
|
||||
var online int
|
||||
var lastSeen sql.NullTime
|
||||
if err := rows.Scan(&a.ID, &a.Token, &a.Hostname, &a.Alias, &a.IPAddress, &a.Arch, &a.OS, &lastSeen, &online); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lastSeen.Valid {
|
||||
a.LastSeenAt = lastSeen.Time
|
||||
}
|
||||
a.Online = online == 1
|
||||
agents = append(agents, a)
|
||||
}
|
||||
return agents, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) SetAgentOffline(id string) error {
|
||||
_, err := s.db.Exec(`UPDATE agents SET online = 0 WHERE id = ?`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) CreateAgentToken(id, token, hostname string) error {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT OR IGNORE INTO agents (id, token, hostname, arch, os, online)
|
||||
VALUES (?, ?, ?, '', '', 0)
|
||||
`, id, token, hostname)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) UpdateAgentAlias(id, alias string) error {
|
||||
_, err := s.db.Exec(`UPDATE agents SET alias = ? WHERE id = ?`, alias, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// ── Users ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
func (s *Store) GetUserHash(username string) (string, error) {
|
||||
var hash string
|
||||
err := s.db.QueryRow(`SELECT password_hash FROM users WHERE username = ?`, username).Scan(&hash)
|
||||
return hash, err
|
||||
}
|
||||
|
||||
func (s *Store) UpsertUser(username, hash string) error {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO users (username, password_hash) VALUES (?, ?)
|
||||
ON CONFLICT(username) DO UPDATE SET password_hash = excluded.password_hash
|
||||
`, username, hash)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) UserExists(username string) (bool, error) {
|
||||
var n int
|
||||
err := s.db.QueryRow(`SELECT COUNT(*) FROM users WHERE username = ?`, username).Scan(&n)
|
||||
return n > 0, err
|
||||
}
|
||||
|
||||
func scanAgent(row *sql.Row) (*Agent, error) {
|
||||
a := &Agent{}
|
||||
var online int
|
||||
var lastSeen sql.NullTime
|
||||
err := row.Scan(&a.ID, &a.Token, &a.Hostname, &a.Alias, &a.IPAddress, &a.Arch, &a.OS, &lastSeen, &online)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lastSeen.Valid {
|
||||
a.LastSeenAt = lastSeen.Time
|
||||
}
|
||||
a.Online = online == 1
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func boolToInt(b bool) int {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
208
server/internal/store/store_test.go
Normal file
208
server/internal/store/store_test.go
Normal file
@ -0,0 +1,208 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func newTestStore(t *testing.T) *Store {
|
||||
t.Helper()
|
||||
s, err := New(":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open in-memory store: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { s.Close() })
|
||||
return s
|
||||
}
|
||||
|
||||
// ── Users ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestUpsertAndGetUserHash(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
|
||||
if err := s.UpsertUser("alice", "hash123"); err != nil {
|
||||
t.Fatalf("UpsertUser: %v", err)
|
||||
}
|
||||
|
||||
h, err := s.GetUserHash("alice")
|
||||
if err != nil {
|
||||
t.Fatalf("GetUserHash: %v", err)
|
||||
}
|
||||
if h != "hash123" {
|
||||
t.Errorf("expected hash123, got %q", h)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserHash_NotFound(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
|
||||
_, err := s.GetUserHash("nobody")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing user, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertUser_Update(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
|
||||
if err := s.UpsertUser("alice", "first"); err != nil {
|
||||
t.Fatalf("UpsertUser: %v", err)
|
||||
}
|
||||
if err := s.UpsertUser("alice", "second"); err != nil {
|
||||
t.Fatalf("UpsertUser update: %v", err)
|
||||
}
|
||||
|
||||
h, err := s.GetUserHash("alice")
|
||||
if err != nil {
|
||||
t.Fatalf("GetUserHash: %v", err)
|
||||
}
|
||||
if h != "second" {
|
||||
t.Errorf("expected second, got %q", h)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserExists(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
|
||||
ok, err := s.UserExists("alice")
|
||||
if err != nil {
|
||||
t.Fatalf("UserExists: %v", err)
|
||||
}
|
||||
if ok {
|
||||
t.Error("expected false for non-existent user")
|
||||
}
|
||||
|
||||
_ = s.UpsertUser("alice", "hash")
|
||||
|
||||
ok, err = s.UserExists("alice")
|
||||
if err != nil {
|
||||
t.Fatalf("UserExists: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Error("expected true after insert")
|
||||
}
|
||||
}
|
||||
|
||||
// ── Agents ────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestCreateAgentToken(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
|
||||
if err := s.CreateAgentToken("id1", "tok1", "host1"); err != nil {
|
||||
t.Fatalf("CreateAgentToken: %v", err)
|
||||
}
|
||||
|
||||
a, err := s.AgentByToken("tok1")
|
||||
if err != nil {
|
||||
t.Fatalf("AgentByToken: %v", err)
|
||||
}
|
||||
if a.ID != "id1" || a.Hostname != "host1" {
|
||||
t.Errorf("unexpected agent: %+v", a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentByToken_NotFound(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
|
||||
_, err := s.AgentByToken("doesnotexist")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertAgent(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
|
||||
a := &Agent{
|
||||
ID: "agent1",
|
||||
Token: "tok1",
|
||||
Hostname: "myhost",
|
||||
Alias: "myalias",
|
||||
IPAddress: "10.0.0.1",
|
||||
Arch: "amd64",
|
||||
OS: "linux",
|
||||
Online: true,
|
||||
}
|
||||
if err := s.UpsertAgent(a); err != nil {
|
||||
t.Fatalf("UpsertAgent: %v", err)
|
||||
}
|
||||
|
||||
got, err := s.GetAgent("agent1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetAgent: %v", err)
|
||||
}
|
||||
if got.Hostname != "myhost" || got.Alias != "myalias" {
|
||||
t.Errorf("unexpected agent: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListAgents(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
|
||||
_ = s.CreateAgentToken("a1", "t1", "host-b")
|
||||
_ = s.CreateAgentToken("a2", "t2", "host-a")
|
||||
|
||||
agents, err := s.ListAgents()
|
||||
if err != nil {
|
||||
t.Fatalf("ListAgents: %v", err)
|
||||
}
|
||||
if len(agents) != 2 {
|
||||
t.Fatalf("expected 2 agents, got %d", len(agents))
|
||||
}
|
||||
// ORDER BY hostname: host-a < host-b
|
||||
if agents[0].Hostname != "host-a" || agents[1].Hostname != "host-b" {
|
||||
t.Errorf("unexpected order: %v %v", agents[0].Hostname, agents[1].Hostname)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetAgentOffline(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
|
||||
_ = s.UpsertAgent(&Agent{
|
||||
ID: "a1",
|
||||
Token: "t1",
|
||||
Hostname: "h1",
|
||||
Online: true,
|
||||
})
|
||||
|
||||
if err := s.SetAgentOffline("a1"); err != nil {
|
||||
t.Fatalf("SetAgentOffline: %v", err)
|
||||
}
|
||||
|
||||
a, err := s.GetAgent("a1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetAgent: %v", err)
|
||||
}
|
||||
if a.Online {
|
||||
t.Error("expected Online=false after SetAgentOffline")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAgentAlias(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
|
||||
_ = s.CreateAgentToken("a1", "t1", "host1")
|
||||
|
||||
if err := s.UpdateAgentAlias("a1", "newalias"); err != nil {
|
||||
t.Fatalf("UpdateAgentAlias: %v", err)
|
||||
}
|
||||
|
||||
a, err := s.GetAgent("a1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetAgent: %v", err)
|
||||
}
|
||||
if a.Alias != "newalias" {
|
||||
t.Errorf("expected alias 'newalias', got %q", a.Alias)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateAgentToken_IdempotentIgnore(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
|
||||
// INSERT OR IGNORE — second call should not error
|
||||
if err := s.CreateAgentToken("id1", "tok1", "h1"); err != nil {
|
||||
t.Fatalf("first call: %v", err)
|
||||
}
|
||||
if err := s.CreateAgentToken("id1", "tok1", "h1"); err != nil {
|
||||
t.Fatalf("second call (should be idempotent): %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user