package api import ( "bytes" "encoding/json" "net/http" "net/http/httptest" "os" "strings" "testing" "time" "github.com/containarr/server/internal/broker" grpcgateway "github.com/containarr/server/internal/grpc" agentv1 "github.com/containarr/server/internal/proto/agentv1" "github.com/containarr/server/internal/store" "github.com/go-chi/chi/v5" "github.com/golang-jwt/jwt/v5" "golang.org/x/crypto/bcrypt" ) // ── helpers ─────────────────────────────────────────────────────────────────── func newTestHandler(t *testing.T) (*Handler, *store.Store, *grpcgateway.Registry, *broker.Broker) { t.Helper() s, err := store.New(":memory:") if err != nil { t.Fatalf("store.New: %v", err) } t.Cleanup(func() { s.Close() }) reg := grpcgateway.NewRegistry() b := broker.New() h := NewHandler(s, reg, b) return h, s, reg, b } func makeJWT(t *testing.T, subject string) string { t.Helper() token := jwt.NewWithClaims(jwt.SigningMethodHS256, &jwtClaims{ RegisteredClaims: jwt.RegisteredClaims{ Subject: subject, ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), IssuedAt: jwt.NewNumericDate(time.Now()), }, }) signed, err := token.SignedString(jwtSecret()) if err != nil { t.Fatalf("makeJWT: %v", err) } return signed } func bearerHeader(token string) string { return "Bearer " + token } func postJSON(t *testing.T, handler http.HandlerFunc, path string, body any) *httptest.ResponseRecorder { t.Helper() b, _ := json.Marshal(body) req := httptest.NewRequest(http.MethodPost, path, bytes.NewReader(b)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() handler(w, req) return w } func postJSONAuth(t *testing.T, handler http.HandlerFunc, path string, body any, token string) *httptest.ResponseRecorder { t.Helper() b, _ := json.Marshal(body) req := httptest.NewRequest(http.MethodPost, path, bytes.NewReader(b)) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerHeader(token)) w := httptest.NewRecorder() // Wrap handler with requireJWT so claims land in context. requireJWT(handler).ServeHTTP(w, req) return w } // ── extractToken ────────────────────────────────────────────────────────────── func TestExtractToken_BearerHeader(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Authorization", "Bearer mytoken") if got := extractToken(req); got != "mytoken" { t.Errorf("expected 'mytoken', got %q", got) } } func TestExtractToken_QueryParam(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/?token=querytoken", nil) if got := extractToken(req); got != "querytoken" { t.Errorf("expected 'querytoken', got %q", got) } } func TestExtractToken_Empty(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) if got := extractToken(req); got != "" { t.Errorf("expected empty, got %q", got) } } func TestExtractToken_ShortAuthHeader(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Authorization", "Bear") // len < 7 if got := extractToken(req); got != "" { t.Errorf("expected empty for short header, got %q", got) } } // ── requireJWT middleware ───────────────────────────────────────────────────── func TestRequireJWT_MissingToken(t *testing.T) { called := false next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { called = true }) req := httptest.NewRequest(http.MethodGet, "/", nil) w := httptest.NewRecorder() requireJWT(next).ServeHTTP(w, req) if w.Code != http.StatusUnauthorized { t.Errorf("expected 401, got %d", w.Code) } if called { t.Error("handler should not be called without token") } } func TestRequireJWT_InvalidToken(t *testing.T) { next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Authorization", "Bearer not.a.real.token") w := httptest.NewRecorder() requireJWT(next).ServeHTTP(w, req) if w.Code != http.StatusUnauthorized { t.Errorf("expected 401, got %d", w.Code) } } func TestRequireJWT_ValidToken(t *testing.T) { token := makeJWT(t, "alice") called := false var gotSubject string next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { called = true c, ok := claimsFromContext(r) if !ok { t.Error("claims not in context") return } gotSubject = c.Subject }) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Authorization", bearerHeader(token)) w := httptest.NewRecorder() requireJWT(next).ServeHTTP(w, req) if !called { t.Error("handler was not called") } if gotSubject != "alice" { t.Errorf("expected subject 'alice', got %q", gotSubject) } } func TestRequireJWT_WrongSecret(t *testing.T) { // Sign with a different secret token := jwt.NewWithClaims(jwt.SigningMethodHS256, &jwtClaims{ RegisteredClaims: jwt.RegisteredClaims{ Subject: "hacker", ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), }, }) signed, _ := token.SignedString([]byte("wrong-secret")) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Authorization", bearerHeader(signed)) w := httptest.NewRecorder() requireJWT(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})).ServeHTTP(w, req) if w.Code != http.StatusUnauthorized { t.Errorf("expected 401, got %d", w.Code) } } // ── Login ───────────────────────────────────────────────────────────────────── func TestLogin_Success(t *testing.T) { h, s, _, _ := newTestHandler(t) hash, _ := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.MinCost) _ = s.UpsertUser("alice", string(hash)) w := postJSON(t, h.Login, "/api/v1/auth/login", map[string]string{ "username": "alice", "password": "password", }) if w.Code != http.StatusOK { t.Fatalf("expected 200, got %d — body: %s", w.Code, w.Body.String()) } var resp map[string]string if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { t.Fatalf("decode response: %v", err) } if resp["token"] == "" { t.Error("expected non-empty token in response") } } func TestLogin_WrongPassword(t *testing.T) { h, s, _, _ := newTestHandler(t) hash, _ := bcrypt.GenerateFromPassword([]byte("correct"), bcrypt.MinCost) _ = s.UpsertUser("alice", string(hash)) w := postJSON(t, h.Login, "/api/v1/auth/login", map[string]string{ "username": "alice", "password": "wrong", }) if w.Code != http.StatusUnauthorized { t.Errorf("expected 401, got %d", w.Code) } } func TestLogin_UnknownUser(t *testing.T) { h, _, _, _ := newTestHandler(t) w := postJSON(t, h.Login, "/api/v1/auth/login", map[string]string{ "username": "nobody", "password": "pass", }) if w.Code != http.StatusUnauthorized { t.Errorf("expected 401, got %d", w.Code) } } func TestLogin_BadBody(t *testing.T) { h, _, _, _ := newTestHandler(t) req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", strings.NewReader("not-json")) w := httptest.NewRecorder() h.Login(w, req) if w.Code != http.StatusBadRequest { t.Errorf("expected 400, got %d", w.Code) } } func TestLogin_EmptyFields(t *testing.T) { h, _, _, _ := newTestHandler(t) w := postJSON(t, h.Login, "/api/v1/auth/login", map[string]string{ "username": "", "password": "", }) if w.Code != http.StatusBadRequest { t.Errorf("expected 400, got %d", w.Code) } } // ── ChangePassword ───────────────────────────────────────────────────────────── func TestChangePassword_Success(t *testing.T) { h, s, _, _ := newTestHandler(t) os.Setenv("JWT_SECRET", "test-secret-change-pw") defer os.Unsetenv("JWT_SECRET") hash, _ := bcrypt.GenerateFromPassword([]byte("oldpass"), bcrypt.MinCost) _ = s.UpsertUser("alice", string(hash)) token := makeJWT(t, "alice") w := postJSONAuth(t, h.ChangePassword, "/api/v1/auth/change-password", map[string]string{ "current_password": "oldpass", "new_password": "newpass", }, token) if w.Code != http.StatusNoContent { t.Fatalf("expected 204, got %d — body: %s", w.Code, w.Body.String()) } // Verify new hash is stored newHash, _ := s.GetUserHash("alice") if bcrypt.CompareHashAndPassword([]byte(newHash), []byte("newpass")) != nil { t.Error("new password hash does not match") } } func TestChangePassword_WrongCurrentPassword(t *testing.T) { h, s, _, _ := newTestHandler(t) hash, _ := bcrypt.GenerateFromPassword([]byte("correct"), bcrypt.MinCost) _ = s.UpsertUser("alice", string(hash)) token := makeJWT(t, "alice") w := postJSONAuth(t, h.ChangePassword, "/api/v1/auth/change-password", map[string]string{ "current_password": "wrong", "new_password": "newpass", }, token) if w.Code != http.StatusUnauthorized { t.Errorf("expected 401, got %d", w.Code) } } // ── ListAgents ──────────────────────────────────────────────────────────────── func TestListAgents_Empty(t *testing.T) { h, _, _, _ := newTestHandler(t) req := httptest.NewRequest(http.MethodGet, "/api/v1/agents", nil) w := httptest.NewRecorder() h.ListAgents(w, req) if w.Code != http.StatusOK { t.Fatalf("expected 200, got %d", w.Code) } var agents []agentDTO if err := json.NewDecoder(w.Body).Decode(&agents); err != nil { t.Fatalf("decode: %v", err) } if len(agents) != 0 { t.Errorf("expected empty list, got %d", len(agents)) } } func TestListAgents_PersistenceAndLive(t *testing.T) { h, s, reg, _ := newTestHandler(t) _ = s.CreateAgentToken("a1", "t1", "host1") // Register a2 in the registry (simulating live agent) reg.Register("a2", "host2", "alias2", "192.168.1.1", "arm64", "linux") _ = s.CreateAgentToken("a2", "t2", "host2") req := httptest.NewRequest(http.MethodGet, "/api/v1/agents", nil) w := httptest.NewRecorder() h.ListAgents(w, req) if w.Code != http.StatusOK { t.Fatalf("expected 200, got %d", w.Code) } var agents []agentDTO json.NewDecoder(w.Body).Decode(&agents) if len(agents) != 2 { t.Fatalf("expected 2 agents, got %d", len(agents)) } // Find a2 — it should be online var a2 *agentDTO for i := range agents { if agents[i].ID == "a2" { a2 = &agents[i] } } if a2 == nil { t.Fatal("a2 not found in list") } if !a2.Online { t.Error("a2 should be online (registered in registry)") } } // ── CreateAgentToken ────────────────────────────────────────────────────────── func TestCreateAgentToken_Success(t *testing.T) { h, _, _, _ := newTestHandler(t) b, _ := json.Marshal(map[string]string{"hostname": "new-agent"}) req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/token", bytes.NewReader(b)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() h.CreateAgentToken(w, req) if w.Code != http.StatusOK { t.Fatalf("expected 200, got %d — body: %s", w.Code, w.Body.String()) } var resp map[string]string json.NewDecoder(w.Body).Decode(&resp) if resp["agent_id"] == "" || resp["token"] == "" { t.Errorf("missing agent_id or token in response: %v", resp) } } func TestCreateAgentToken_MissingHostname(t *testing.T) { h, _, _, _ := newTestHandler(t) b, _ := json.Marshal(map[string]string{"hostname": ""}) req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/token", bytes.NewReader(b)) w := httptest.NewRecorder() h.CreateAgentToken(w, req) if w.Code != http.StatusBadRequest { t.Errorf("expected 400, got %d", w.Code) } } // ── UpdateAgent ─────────────────────────────────────────────────────────────── func TestUpdateAgent_Success(t *testing.T) { h, s, reg, _ := newTestHandler(t) _ = s.CreateAgentToken("a1", "t1", "host1") reg.Register("a1", "host1", "old", "ip", "arch", "os") body, _ := json.Marshal(map[string]string{"alias": "new-alias"}) router := chi.NewRouter() router.Patch("/api/v1/agents/{agentID}", h.UpdateAgent) req, _ := http.NewRequest(http.MethodPatch, "/api/v1/agents/a1", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("expected 200, got %d — body: %s", w.Code, w.Body.String()) } var resp agentDTO json.NewDecoder(w.Body).Decode(&resp) if resp.Alias != "new-alias" { t.Errorf("expected alias 'new-alias', got %q", resp.Alias) } // Confirm registry also updated state, ok := reg.Get("a1") if !ok { t.Fatal("agent not in registry") } if state.Alias != "new-alias" { t.Errorf("registry alias not updated, got %q", state.Alias) } } // ── ListContainers ───────────────────────────────────────────────────────────── func TestListContainers_Empty(t *testing.T) { h, _, _, _ := newTestHandler(t) req := httptest.NewRequest(http.MethodGet, "/api/v1/containers", nil) w := httptest.NewRecorder() h.ListContainers(w, req) if w.Code != http.StatusOK { t.Fatalf("expected 200, got %d", w.Code) } } func TestListContainers_WithData(t *testing.T) { h, _, reg, _ := newTestHandler(t) reg.Register("a1", "host1", "alias1", "10.0.0.1", "amd64", "linux") reg.UpdateContainers("a1", []*agentv1.ContainerInfo{ {Id: "c1", Name: "web"}, {Id: "c2", Name: "db"}, }) req := httptest.NewRequest(http.MethodGet, "/api/v1/containers", nil) w := httptest.NewRecorder() h.ListContainers(w, req) if w.Code != http.StatusOK { t.Fatalf("expected 200, got %d", w.Code) } var out []struct { AgentID string `json:"agent_id"` Container *agentv1.ContainerInfo `json:"container"` } json.NewDecoder(w.Body).Decode(&out) if len(out) != 2 { t.Errorf("expected 2 containers, got %d", len(out)) } } // ── ContainerAction ─────────────────────────────────────────────────────────── func TestContainerAction_AgentNotConnected(t *testing.T) { h, _, _, _ := newTestHandler(t) body, _ := json.Marshal(map[string]string{"action": "start"}) req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/ghost/containers/c1/action", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") router := chi.NewRouter() router.Post("/api/v1/agents/{agentID}/containers/{containerID}/action", h.ContainerAction) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusServiceUnavailable { t.Errorf("expected 503, got %d", w.Code) } } func TestContainerAction_InvalidAction(t *testing.T) { h, _, reg, _ := newTestHandler(t) reg.Register("a1", "h", "a", "ip", "arch", "os") body, _ := json.Marshal(map[string]string{"action": "explode"}) req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a1/containers/c1/action", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") router := chi.NewRouter() router.Post("/api/v1/agents/{agentID}/containers/{containerID}/action", h.ContainerAction) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusBadRequest { t.Errorf("expected 400, got %d", w.Code) } } func TestContainerAction_Success(t *testing.T) { h, _, reg, _ := newTestHandler(t) reg.Register("a1", "h", "a", "ip", "arch", "os") body, _ := json.Marshal(map[string]string{"action": "stop"}) req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a1/containers/c1/action", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") router := chi.NewRouter() router.Post("/api/v1/agents/{agentID}/containers/{containerID}/action", h.ContainerAction) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("expected 200, got %d — body: %s", w.Code, w.Body.String()) } var resp map[string]string json.NewDecoder(w.Body).Decode(&resp) if resp["command_id"] == "" { t.Error("expected command_id in response") } }