feat: add first page with auth and containers list and agents
This commit is contained in:
550
server/internal/api/api_test.go
Normal file
550
server/internal/api/api_test.go
Normal file
@ -0,0 +1,550 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/containarr/server/internal/broker"
|
||||
grpcgateway "github.com/containarr/server/internal/grpc"
|
||||
agentv1 "github.com/containarr/server/internal/proto/agentv1"
|
||||
"github.com/containarr/server/internal/store"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// ── helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
func newTestHandler(t *testing.T) (*Handler, *store.Store, *grpcgateway.Registry, *broker.Broker) {
|
||||
t.Helper()
|
||||
s, err := store.New(":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("store.New: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { s.Close() })
|
||||
|
||||
reg := grpcgateway.NewRegistry()
|
||||
b := broker.New()
|
||||
h := NewHandler(s, reg, b)
|
||||
return h, s, reg, b
|
||||
}
|
||||
|
||||
func makeJWT(t *testing.T, subject string) string {
|
||||
t.Helper()
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &jwtClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: subject,
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
})
|
||||
signed, err := token.SignedString(jwtSecret())
|
||||
if err != nil {
|
||||
t.Fatalf("makeJWT: %v", err)
|
||||
}
|
||||
return signed
|
||||
}
|
||||
|
||||
func bearerHeader(token string) string {
|
||||
return "Bearer " + token
|
||||
}
|
||||
|
||||
func postJSON(t *testing.T, handler http.HandlerFunc, path string, body any) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
b, _ := json.Marshal(body)
|
||||
req := httptest.NewRequest(http.MethodPost, path, bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
handler(w, req)
|
||||
return w
|
||||
}
|
||||
|
||||
func postJSONAuth(t *testing.T, handler http.HandlerFunc, path string, body any, token string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
b, _ := json.Marshal(body)
|
||||
req := httptest.NewRequest(http.MethodPost, path, bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", bearerHeader(token))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Wrap handler with requireJWT so claims land in context.
|
||||
requireJWT(handler).ServeHTTP(w, req)
|
||||
return w
|
||||
}
|
||||
|
||||
// ── extractToken ──────────────────────────────────────────────────────────────
|
||||
|
||||
func TestExtractToken_BearerHeader(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer mytoken")
|
||||
if got := extractToken(req); got != "mytoken" {
|
||||
t.Errorf("expected 'mytoken', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractToken_QueryParam(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/?token=querytoken", nil)
|
||||
if got := extractToken(req); got != "querytoken" {
|
||||
t.Errorf("expected 'querytoken', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractToken_Empty(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
if got := extractToken(req); got != "" {
|
||||
t.Errorf("expected empty, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractToken_ShortAuthHeader(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Authorization", "Bear") // len < 7
|
||||
if got := extractToken(req); got != "" {
|
||||
t.Errorf("expected empty for short header, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ── requireJWT middleware ─────────────────────────────────────────────────────
|
||||
|
||||
func TestRequireJWT_MissingToken(t *testing.T) {
|
||||
called := false
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { called = true })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
requireJWT(next).ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
if called {
|
||||
t.Error("handler should not be called without token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireJWT_InvalidToken(t *testing.T) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer not.a.real.token")
|
||||
w := httptest.NewRecorder()
|
||||
requireJWT(next).ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireJWT_ValidToken(t *testing.T) {
|
||||
token := makeJWT(t, "alice")
|
||||
called := false
|
||||
var gotSubject string
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
c, ok := claimsFromContext(r)
|
||||
if !ok {
|
||||
t.Error("claims not in context")
|
||||
return
|
||||
}
|
||||
gotSubject = c.Subject
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Authorization", bearerHeader(token))
|
||||
w := httptest.NewRecorder()
|
||||
requireJWT(next).ServeHTTP(w, req)
|
||||
|
||||
if !called {
|
||||
t.Error("handler was not called")
|
||||
}
|
||||
if gotSubject != "alice" {
|
||||
t.Errorf("expected subject 'alice', got %q", gotSubject)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireJWT_WrongSecret(t *testing.T) {
|
||||
// Sign with a different secret
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &jwtClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: "hacker",
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
},
|
||||
})
|
||||
signed, _ := token.SignedString([]byte("wrong-secret"))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Authorization", bearerHeader(signed))
|
||||
w := httptest.NewRecorder()
|
||||
requireJWT(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})).ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Login ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestLogin_Success(t *testing.T) {
|
||||
h, s, _, _ := newTestHandler(t)
|
||||
|
||||
hash, _ := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.MinCost)
|
||||
_ = s.UpsertUser("alice", string(hash))
|
||||
|
||||
w := postJSON(t, h.Login, "/api/v1/auth/login", map[string]string{
|
||||
"username": "alice",
|
||||
"password": "password",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d — body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]string
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if resp["token"] == "" {
|
||||
t.Error("expected non-empty token in response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogin_WrongPassword(t *testing.T) {
|
||||
h, s, _, _ := newTestHandler(t)
|
||||
|
||||
hash, _ := bcrypt.GenerateFromPassword([]byte("correct"), bcrypt.MinCost)
|
||||
_ = s.UpsertUser("alice", string(hash))
|
||||
|
||||
w := postJSON(t, h.Login, "/api/v1/auth/login", map[string]string{
|
||||
"username": "alice",
|
||||
"password": "wrong",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogin_UnknownUser(t *testing.T) {
|
||||
h, _, _, _ := newTestHandler(t)
|
||||
|
||||
w := postJSON(t, h.Login, "/api/v1/auth/login", map[string]string{
|
||||
"username": "nobody",
|
||||
"password": "pass",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogin_BadBody(t *testing.T) {
|
||||
h, _, _, _ := newTestHandler(t)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", strings.NewReader("not-json"))
|
||||
w := httptest.NewRecorder()
|
||||
h.Login(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogin_EmptyFields(t *testing.T) {
|
||||
h, _, _, _ := newTestHandler(t)
|
||||
|
||||
w := postJSON(t, h.Login, "/api/v1/auth/login", map[string]string{
|
||||
"username": "",
|
||||
"password": "",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ── ChangePassword ─────────────────────────────────────────────────────────────
|
||||
|
||||
func TestChangePassword_Success(t *testing.T) {
|
||||
h, s, _, _ := newTestHandler(t)
|
||||
os.Setenv("JWT_SECRET", "test-secret-change-pw")
|
||||
defer os.Unsetenv("JWT_SECRET")
|
||||
|
||||
hash, _ := bcrypt.GenerateFromPassword([]byte("oldpass"), bcrypt.MinCost)
|
||||
_ = s.UpsertUser("alice", string(hash))
|
||||
|
||||
token := makeJWT(t, "alice")
|
||||
w := postJSONAuth(t, h.ChangePassword, "/api/v1/auth/change-password", map[string]string{
|
||||
"current_password": "oldpass",
|
||||
"new_password": "newpass",
|
||||
}, token)
|
||||
|
||||
if w.Code != http.StatusNoContent {
|
||||
t.Fatalf("expected 204, got %d — body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify new hash is stored
|
||||
newHash, _ := s.GetUserHash("alice")
|
||||
if bcrypt.CompareHashAndPassword([]byte(newHash), []byte("newpass")) != nil {
|
||||
t.Error("new password hash does not match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChangePassword_WrongCurrentPassword(t *testing.T) {
|
||||
h, s, _, _ := newTestHandler(t)
|
||||
|
||||
hash, _ := bcrypt.GenerateFromPassword([]byte("correct"), bcrypt.MinCost)
|
||||
_ = s.UpsertUser("alice", string(hash))
|
||||
|
||||
token := makeJWT(t, "alice")
|
||||
w := postJSONAuth(t, h.ChangePassword, "/api/v1/auth/change-password", map[string]string{
|
||||
"current_password": "wrong",
|
||||
"new_password": "newpass",
|
||||
}, token)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ── ListAgents ────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestListAgents_Empty(t *testing.T) {
|
||||
h, _, _, _ := newTestHandler(t)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.ListAgents(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var agents []agentDTO
|
||||
if err := json.NewDecoder(w.Body).Decode(&agents); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if len(agents) != 0 {
|
||||
t.Errorf("expected empty list, got %d", len(agents))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListAgents_PersistenceAndLive(t *testing.T) {
|
||||
h, s, reg, _ := newTestHandler(t)
|
||||
|
||||
_ = s.CreateAgentToken("a1", "t1", "host1")
|
||||
// Register a2 in the registry (simulating live agent)
|
||||
reg.Register("a2", "host2", "alias2", "192.168.1.1", "arm64", "linux")
|
||||
_ = s.CreateAgentToken("a2", "t2", "host2")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.ListAgents(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var agents []agentDTO
|
||||
json.NewDecoder(w.Body).Decode(&agents)
|
||||
|
||||
if len(agents) != 2 {
|
||||
t.Fatalf("expected 2 agents, got %d", len(agents))
|
||||
}
|
||||
|
||||
// Find a2 — it should be online
|
||||
var a2 *agentDTO
|
||||
for i := range agents {
|
||||
if agents[i].ID == "a2" {
|
||||
a2 = &agents[i]
|
||||
}
|
||||
}
|
||||
if a2 == nil {
|
||||
t.Fatal("a2 not found in list")
|
||||
}
|
||||
if !a2.Online {
|
||||
t.Error("a2 should be online (registered in registry)")
|
||||
}
|
||||
}
|
||||
|
||||
// ── CreateAgentToken ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestCreateAgentToken_Success(t *testing.T) {
|
||||
h, _, _, _ := newTestHandler(t)
|
||||
|
||||
b, _ := json.Marshal(map[string]string{"hostname": "new-agent"})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/token", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
h.CreateAgentToken(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d — body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]string
|
||||
json.NewDecoder(w.Body).Decode(&resp)
|
||||
|
||||
if resp["agent_id"] == "" || resp["token"] == "" {
|
||||
t.Errorf("missing agent_id or token in response: %v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateAgentToken_MissingHostname(t *testing.T) {
|
||||
h, _, _, _ := newTestHandler(t)
|
||||
|
||||
b, _ := json.Marshal(map[string]string{"hostname": ""})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/token", bytes.NewReader(b))
|
||||
w := httptest.NewRecorder()
|
||||
h.CreateAgentToken(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ── UpdateAgent ───────────────────────────────────────────────────────────────
|
||||
|
||||
func TestUpdateAgent_Success(t *testing.T) {
|
||||
h, s, reg, _ := newTestHandler(t)
|
||||
|
||||
_ = s.CreateAgentToken("a1", "t1", "host1")
|
||||
reg.Register("a1", "host1", "old", "ip", "arch", "os")
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"alias": "new-alias"})
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Patch("/api/v1/agents/{agentID}", h.UpdateAgent)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPatch, "/api/v1/agents/a1", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d — body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp agentDTO
|
||||
json.NewDecoder(w.Body).Decode(&resp)
|
||||
if resp.Alias != "new-alias" {
|
||||
t.Errorf("expected alias 'new-alias', got %q", resp.Alias)
|
||||
}
|
||||
|
||||
// Confirm registry also updated
|
||||
state, ok := reg.Get("a1")
|
||||
if !ok {
|
||||
t.Fatal("agent not in registry")
|
||||
}
|
||||
if state.Alias != "new-alias" {
|
||||
t.Errorf("registry alias not updated, got %q", state.Alias)
|
||||
}
|
||||
}
|
||||
|
||||
// ── ListContainers ─────────────────────────────────────────────────────────────
|
||||
|
||||
func TestListContainers_Empty(t *testing.T) {
|
||||
h, _, _, _ := newTestHandler(t)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/containers", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.ListContainers(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListContainers_WithData(t *testing.T) {
|
||||
h, _, reg, _ := newTestHandler(t)
|
||||
|
||||
reg.Register("a1", "host1", "alias1", "10.0.0.1", "amd64", "linux")
|
||||
reg.UpdateContainers("a1", []*agentv1.ContainerInfo{
|
||||
{Id: "c1", Name: "web"},
|
||||
{Id: "c2", Name: "db"},
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/containers", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.ListContainers(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var out []struct {
|
||||
AgentID string `json:"agent_id"`
|
||||
Container *agentv1.ContainerInfo `json:"container"`
|
||||
}
|
||||
json.NewDecoder(w.Body).Decode(&out)
|
||||
|
||||
if len(out) != 2 {
|
||||
t.Errorf("expected 2 containers, got %d", len(out))
|
||||
}
|
||||
}
|
||||
|
||||
// ── ContainerAction ───────────────────────────────────────────────────────────
|
||||
|
||||
func TestContainerAction_AgentNotConnected(t *testing.T) {
|
||||
h, _, _, _ := newTestHandler(t)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"action": "start"})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/ghost/containers/c1/action", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Post("/api/v1/agents/{agentID}/containers/{containerID}/action", h.ContainerAction)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("expected 503, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainerAction_InvalidAction(t *testing.T) {
|
||||
h, _, reg, _ := newTestHandler(t)
|
||||
reg.Register("a1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"action": "explode"})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a1/containers/c1/action", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Post("/api/v1/agents/{agentID}/containers/{containerID}/action", h.ContainerAction)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainerAction_Success(t *testing.T) {
|
||||
h, _, reg, _ := newTestHandler(t)
|
||||
reg.Register("a1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"action": "stop"})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a1/containers/c1/action", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Post("/api/v1/agents/{agentID}/containers/{containerID}/action", h.ContainerAction)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d — body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]string
|
||||
json.NewDecoder(w.Body).Decode(&resp)
|
||||
if resp["command_id"] == "" {
|
||||
t.Error("expected command_id in response")
|
||||
}
|
||||
}
|
||||
125
server/internal/api/auth.go
Normal file
125
server/internal/api/auth.go
Normal file
@ -0,0 +1,125 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func jwtSecret() []byte {
|
||||
if s := os.Getenv("JWT_SECRET"); s != "" {
|
||||
return []byte(s)
|
||||
}
|
||||
return []byte("dev-secret-change-me")
|
||||
}
|
||||
|
||||
type jwtClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
var body struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil || body.Username == "" || body.Password == "" {
|
||||
http.Error(w, "invalid body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
hash, err := h.store.GetUserHash(body.Username)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid credentials", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if bcrypt.CompareHashAndPassword([]byte(hash), []byte(body.Password)) != nil {
|
||||
http.Error(w, "invalid credentials", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &jwtClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: body.Username,
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
})
|
||||
signed, err := token.SignedString(jwtSecret())
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
jsonOK(w, map[string]string{"token": signed})
|
||||
}
|
||||
|
||||
func (h *Handler) ChangePassword(w http.ResponseWriter, r *http.Request) {
|
||||
claims, ok := claimsFromContext(r)
|
||||
if !ok {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
var body struct {
|
||||
CurrentPassword string `json:"current_password"`
|
||||
NewPassword string `json:"new_password"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil || body.NewPassword == "" {
|
||||
http.Error(w, "invalid body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
hash, err := h.store.GetUserHash(claims.Subject)
|
||||
if err != nil {
|
||||
http.Error(w, "user not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if bcrypt.CompareHashAndPassword([]byte(hash), []byte(body.CurrentPassword)) != nil {
|
||||
http.Error(w, "mot de passe actuel incorrect", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
newHash, err := bcrypt.GenerateFromPassword([]byte(body.NewPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err := h.store.UpsertUser(claims.Subject, string(newHash)); err != nil {
|
||||
http.Error(w, "store error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func requireJWT(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw := extractToken(r)
|
||||
if raw == "" {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
t, err := jwt.ParseWithClaims(raw, &jwtClaims{}, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, jwt.ErrSignatureInvalid
|
||||
}
|
||||
return jwtSecret(), nil
|
||||
})
|
||||
if err != nil || !t.Valid {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(
|
||||
contextWithClaims(r.Context(), t.Claims.(*jwtClaims)),
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
func extractToken(r *http.Request) string {
|
||||
if auth := r.Header.Get("Authorization"); len(auth) > 7 && auth[:7] == "Bearer " {
|
||||
return auth[7:]
|
||||
}
|
||||
return r.URL.Query().Get("token")
|
||||
}
|
||||
16
server/internal/api/context.go
Normal file
16
server/internal/api/context.go
Normal file
@ -0,0 +1,16 @@
|
||||
package api
|
||||
|
||||
import "context"
|
||||
|
||||
type contextKey int
|
||||
|
||||
const claimsKey contextKey = iota
|
||||
|
||||
func contextWithClaims(ctx context.Context, c *jwtClaims) context.Context {
|
||||
return context.WithValue(ctx, claimsKey, c)
|
||||
}
|
||||
|
||||
func claimsFromContext(r interface{ Context() context.Context }) (*jwtClaims, bool) {
|
||||
c, ok := r.Context().Value(claimsKey).(*jwtClaims)
|
||||
return c, ok && c != nil
|
||||
}
|
||||
301
server/internal/api/handlers.go
Normal file
301
server/internal/api/handlers.go
Normal file
@ -0,0 +1,301 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/containarr/server/internal/broker"
|
||||
grpcgateway "github.com/containarr/server/internal/grpc"
|
||||
agentv1 "github.com/containarr/server/internal/proto/agentv1"
|
||||
"github.com/containarr/server/internal/store"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
|
||||
type Handler struct {
|
||||
store *store.Store
|
||||
registry *grpcgateway.Registry
|
||||
broker *broker.Broker
|
||||
}
|
||||
|
||||
func NewHandler(s *store.Store, r *grpcgateway.Registry, b *broker.Broker) *Handler {
|
||||
return &Handler{store: s, registry: r, broker: b}
|
||||
}
|
||||
|
||||
type agentDTO struct {
|
||||
ID string `json:"id"`
|
||||
Hostname string `json:"hostname"`
|
||||
Alias string `json:"alias"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
Arch string `json:"arch"`
|
||||
OS string `json:"os"`
|
||||
Online bool `json:"online"`
|
||||
LastSeenAt time.Time `json:"last_seen_at"`
|
||||
}
|
||||
|
||||
// ── Agents ────────────────────────────────────────────────────────────────────
|
||||
|
||||
func (h *Handler) ListAgents(w http.ResponseWriter, r *http.Request) {
|
||||
persisted, err := h.store.ListAgents()
|
||||
if err != nil {
|
||||
http.Error(w, "store error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
liveByID := map[string]*grpcgateway.AgentState{}
|
||||
for _, s := range h.registry.List() {
|
||||
liveByID[s.ID] = s
|
||||
}
|
||||
out := make([]agentDTO, 0, len(persisted))
|
||||
for _, a := range persisted {
|
||||
dto := agentDTO{
|
||||
ID: a.ID,
|
||||
Hostname: a.Hostname,
|
||||
Alias: a.Alias,
|
||||
IPAddress: a.IPAddress,
|
||||
Arch: a.Arch,
|
||||
OS: a.OS,
|
||||
}
|
||||
if live, ok := liveByID[a.ID]; ok {
|
||||
dto.Online = true
|
||||
dto.IPAddress = live.IPAddress
|
||||
dto.LastSeenAt = live.LastSeenAt
|
||||
}
|
||||
out = append(out, dto)
|
||||
}
|
||||
jsonOK(w, out)
|
||||
}
|
||||
|
||||
func (h *Handler) UpdateAgent(w http.ResponseWriter, r *http.Request) {
|
||||
agentID := chi.URLParam(r, "agentID")
|
||||
var body struct {
|
||||
Alias string `json:"alias"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
http.Error(w, "invalid body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err := h.store.UpdateAgentAlias(agentID, body.Alias); err != nil {
|
||||
http.Error(w, "store error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
h.registry.UpdateAlias(agentID, body.Alias)
|
||||
|
||||
a, err := h.store.GetAgent(agentID)
|
||||
if err != nil {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
jsonOK(w, agentDTO{
|
||||
ID: a.ID,
|
||||
Hostname: a.Hostname,
|
||||
Alias: a.Alias,
|
||||
IPAddress: a.IPAddress,
|
||||
Arch: a.Arch,
|
||||
OS: a.OS,
|
||||
Online: a.Online,
|
||||
})
|
||||
}
|
||||
|
||||
// ── Containers ────────────────────────────────────────────────────────────────
|
||||
|
||||
func (h *Handler) ListContainers(w http.ResponseWriter, r *http.Request) {
|
||||
type containerDTO struct {
|
||||
AgentID string `json:"agent_id"`
|
||||
Hostname string `json:"hostname"`
|
||||
Alias string `json:"alias"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
Container *agentv1.ContainerInfo `json:"container"`
|
||||
}
|
||||
var out []containerDTO
|
||||
for _, agent := range h.registry.List() {
|
||||
for _, c := range agent.Containers {
|
||||
out = append(out, containerDTO{
|
||||
AgentID: agent.ID,
|
||||
Hostname: agent.Hostname,
|
||||
Alias: agent.Alias,
|
||||
IPAddress: agent.IPAddress,
|
||||
Container: c,
|
||||
})
|
||||
}
|
||||
}
|
||||
jsonOK(w, out)
|
||||
}
|
||||
|
||||
func (h *Handler) ContainerAction(w http.ResponseWriter, r *http.Request) {
|
||||
agentID := chi.URLParam(r, "agentID")
|
||||
containerID := chi.URLParam(r, "containerID")
|
||||
|
||||
var body struct {
|
||||
Action string `json:"action"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
http.Error(w, "invalid body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
action, ok := map[string]agentv1.ContainerAction{
|
||||
"start": agentv1.ContainerAction_CONTAINER_ACTION_START,
|
||||
"stop": agentv1.ContainerAction_CONTAINER_ACTION_STOP,
|
||||
"restart": agentv1.ContainerAction_CONTAINER_ACTION_RESTART,
|
||||
"remove": agentv1.ContainerAction_CONTAINER_ACTION_REMOVE,
|
||||
}[body.Action]
|
||||
if !ok {
|
||||
http.Error(w, "unknown action", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
cmdID := uuid.NewString()
|
||||
sent := h.registry.Send(agentID, &agentv1.ServerMessage{
|
||||
Payload: &agentv1.ServerMessage_ContainerCmd{
|
||||
ContainerCmd: &agentv1.ContainerCommand{
|
||||
CommandId: cmdID,
|
||||
ContainerId: containerID,
|
||||
Action: action,
|
||||
},
|
||||
},
|
||||
})
|
||||
if !sent {
|
||||
http.Error(w, "agent not connected", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
jsonOK(w, map[string]string{"command_id": cmdID})
|
||||
}
|
||||
|
||||
// ── Agent token provisioning ──────────────────────────────────────────────────
|
||||
|
||||
func (h *Handler) CreateAgentToken(w http.ResponseWriter, r *http.Request) {
|
||||
var body struct {
|
||||
Hostname string `json:"hostname"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil || body.Hostname == "" {
|
||||
http.Error(w, "hostname required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
id := uuid.NewString()
|
||||
token := uuid.NewString()
|
||||
if err := h.store.CreateAgentToken(id, token, body.Hostname); err != nil {
|
||||
http.Error(w, "store error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
jsonOK(w, map[string]string{"agent_id": id, "token": token})
|
||||
}
|
||||
|
||||
// ── Container log stream ──────────────────────────────────────────────────────
|
||||
|
||||
func (h *Handler) LogsWS(w http.ResponseWriter, r *http.Request) {
|
||||
agentID := chi.URLParam(r, "agentID")
|
||||
containerID := chi.URLParam(r, "containerID")
|
||||
|
||||
follow := r.URL.Query().Get("follow") != "false"
|
||||
tail := int32(100)
|
||||
if t := r.URL.Query().Get("tail"); t != "" {
|
||||
if n, err := strconv.Atoi(t); err == nil && n > 0 {
|
||||
tail = int32(n)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
sent := h.registry.Send(agentID, &agentv1.ServerMessage{
|
||||
Payload: &agentv1.ServerMessage_StreamLogs{
|
||||
StreamLogs: &agentv1.StreamLogsCommand{
|
||||
CommandId: uuid.NewString(),
|
||||
ContainerId: containerID,
|
||||
Follow: follow,
|
||||
Tail: tail,
|
||||
},
|
||||
},
|
||||
})
|
||||
if !sent {
|
||||
conn.WriteMessage(websocket.TextMessage, []byte(`{"error":"agent not connected"}`))
|
||||
return
|
||||
}
|
||||
|
||||
sub := h.broker.Subscribe()
|
||||
defer h.broker.Unsubscribe(sub)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
if _, _, err := conn.ReadMessage(); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case raw, ok := <-sub:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
AgentID string `json:"agent_id"`
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
}
|
||||
if json.Unmarshal(raw, &envelope) != nil {
|
||||
continue
|
||||
}
|
||||
if envelope.Type != "log.chunk" || envelope.AgentID != agentID {
|
||||
continue
|
||||
}
|
||||
var chunk struct {
|
||||
ContainerID string `json:"container_id"`
|
||||
Stream string `json:"stream"`
|
||||
Data []byte `json:"data"`
|
||||
}
|
||||
if json.Unmarshal(envelope.Payload, &chunk) != nil {
|
||||
continue
|
||||
}
|
||||
if chunk.ContainerID != containerID {
|
||||
continue
|
||||
}
|
||||
msg, _ := json.Marshal(map[string]string{
|
||||
"stream": chunk.Stream,
|
||||
"line": string(chunk.Data),
|
||||
})
|
||||
if err := conn.WriteMessage(websocket.TextMessage, msg); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── WebSocket event stream ────────────────────────────────────────────────────
|
||||
|
||||
func (h *Handler) EventsWS(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
sub := h.broker.Subscribe()
|
||||
defer h.broker.Unsubscribe(sub)
|
||||
|
||||
for data := range sub {
|
||||
if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func jsonOK(w http.ResponseWriter, v any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(v)
|
||||
}
|
||||
35
server/internal/api/router.go
Normal file
35
server/internal/api/router.go
Normal file
@ -0,0 +1,35 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
)
|
||||
|
||||
func NewRouter(h *Handler) http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.Logger)
|
||||
r.Use(middleware.Recoverer)
|
||||
r.Use(middleware.RealIP)
|
||||
|
||||
r.Route("/api/v1", func(r chi.Router) {
|
||||
r.Post("/auth/login", h.Login)
|
||||
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireJWT)
|
||||
r.Post("/auth/change-password", h.ChangePassword)
|
||||
r.Get("/agents", h.ListAgents)
|
||||
r.Post("/agents/token", h.CreateAgentToken)
|
||||
r.Patch("/agents/{agentID}", h.UpdateAgent)
|
||||
r.Get("/containers", h.ListContainers)
|
||||
r.Post("/agents/{agentID}/containers/{containerID}/action", h.ContainerAction)
|
||||
r.Get("/agents/{agentID}/containers/{containerID}/logs", h.LogsWS)
|
||||
r.Get("/events", h.EventsWS)
|
||||
})
|
||||
})
|
||||
|
||||
r.Handle("/*", http.FileServer(http.Dir("./web/dist")))
|
||||
|
||||
return r
|
||||
}
|
||||
Reference in New Issue
Block a user