feat: add auto update
This commit is contained in:
@ -18,6 +18,7 @@ import (
|
||||
"github.com/containarr/server/internal/broker"
|
||||
grpcgateway "github.com/containarr/server/internal/grpc"
|
||||
agentv1 "github.com/containarr/server/internal/proto/agentv1"
|
||||
"github.com/containarr/server/internal/scheduler"
|
||||
"github.com/containarr/server/internal/store"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
@ -39,6 +40,14 @@ func main() {
|
||||
reg := grpcgateway.NewRegistry()
|
||||
brk := broker.New()
|
||||
|
||||
// Root context cancelled on shutdown signal.
|
||||
rootCtx, rootCancel := context.WithCancel(context.Background())
|
||||
defer rootCancel()
|
||||
|
||||
// Scheduler.
|
||||
sched := scheduler.New(scheduler.NewStoreAdapter(db), reg)
|
||||
go sched.Start(rootCtx)
|
||||
|
||||
// gRPC server.
|
||||
gw := grpcgateway.NewGateway(db, reg, brk)
|
||||
grpcServer := grpc.NewServer()
|
||||
@ -76,6 +85,7 @@ func main() {
|
||||
<-quit
|
||||
|
||||
slog.Info("shutting down")
|
||||
rootCancel()
|
||||
grpcServer.GracefulStop()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@ -2,6 +2,7 @@ package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -801,3 +802,246 @@ func TestContainerAction_Success(t *testing.T) {
|
||||
t.Error("expected command_id in response")
|
||||
}
|
||||
}
|
||||
|
||||
// newCancelledRequest builds a test request whose context is already
// cancelled by the time it is returned.
//
// A nil body is forwarded to httptest.NewRequest as a literal nil; a non-nil
// body is forwarded unchanged.
func newCancelledRequest(method, target string, body *bytes.Reader) *http.Request {
	cancelledCtx, cancel := context.WithCancel(context.Background())
	cancel()

	if body == nil {
		return httptest.NewRequest(method, target, nil).WithContext(cancelledCtx)
	}
	return httptest.NewRequest(method, target, body).WithContext(cancelledCtx)
}
|
||||
|
||||
// ── FsList ────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestFsList_AgentNotFound(t *testing.T) {
|
||||
h, _, _, _ := newTestHandler(t)
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Get("/api/v1/agents/{agentID}/fs/list", h.FsList)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents/ghost/fs/list?path=/tmp", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFsList_Timeout(t *testing.T) {
|
||||
h, _, reg, _ := newTestHandler(t)
|
||||
reg.Register("a1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Get("/api/v1/agents/{agentID}/fs/list", h.FsList)
|
||||
|
||||
// Use cancelled context to force immediate timeout on the agent wait.
|
||||
req := newCancelledRequest(http.MethodGet, "/api/v1/agents/a1/fs/list?path=/tmp", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Either 504 (timeout) or 404 (send failed because channel was full/cancelled).
|
||||
if w.Code != http.StatusGatewayTimeout && w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 504 or 404, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func TestFsList_MissingPath(t *testing.T) {
|
||||
h, _, reg, _ := newTestHandler(t)
|
||||
reg.Register("a1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Get("/api/v1/agents/{agentID}/fs/list", h.FsList)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents/a1/fs/list", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ── FsRead ────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestFsRead_AgentNotFound(t *testing.T) {
|
||||
h, _, _, _ := newTestHandler(t)
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Get("/api/v1/agents/{agentID}/fs/read", h.FsRead)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents/ghost/fs/read?path=/etc/hosts", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFsRead_MissingPath(t *testing.T) {
|
||||
h, _, reg, _ := newTestHandler(t)
|
||||
reg.Register("a1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Get("/api/v1/agents/{agentID}/fs/read", h.FsRead)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents/a1/fs/read", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ── FsWrite ───────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestFsWrite_AgentNotFound(t *testing.T) {
|
||||
h, _, _, _ := newTestHandler(t)
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Post("/api/v1/agents/{agentID}/fs/write", h.FsWrite)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"path": "/tmp/test.txt", "content": "hello"})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/ghost/fs/write", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFsWrite_MissingPath(t *testing.T) {
|
||||
h, _, reg, _ := newTestHandler(t)
|
||||
reg.Register("a1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Post("/api/v1/agents/{agentID}/fs/write", h.FsWrite)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"content": "hello"})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a1/fs/write", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ── FsMkdir ───────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestFsMkdir_AgentNotFound(t *testing.T) {
|
||||
h, _, _, _ := newTestHandler(t)
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Post("/api/v1/agents/{agentID}/fs/mkdir", h.FsMkdir)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"path": "/opt/stacks/nouveau-dossier"})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/ghost/fs/mkdir", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFsMkdir_InvalidBody(t *testing.T) {
|
||||
h, _, reg, _ := newTestHandler(t)
|
||||
reg.Register("a1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Post("/api/v1/agents/{agentID}/fs/mkdir", h.FsMkdir)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a1/fs/mkdir", bytes.NewReader([]byte("not-json")))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ── ComposeAction ─────────────────────────────────────────────────────────────
|
||||
|
||||
func TestComposeAction_AgentNotFound(t *testing.T) {
|
||||
h, _, _, _ := newTestHandler(t)
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Post("/api/v1/agents/{agentID}/compose", h.ComposeAction)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"path": "/opt/stack", "action": "up"})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/ghost/compose", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComposeAction_InvalidAction(t *testing.T) {
|
||||
h, _, reg, _ := newTestHandler(t)
|
||||
reg.Register("a1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Post("/api/v1/agents/{agentID}/compose", h.ComposeAction)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"path": "/opt/stack", "action": "restart"})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a1/compose", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComposeAction_MissingFields(t *testing.T) {
|
||||
h, _, _, _ := newTestHandler(t)
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Post("/api/v1/agents/{agentID}/compose", h.ComposeAction)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"action": "up"})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/ghost/compose", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComposeAction_Timeout(t *testing.T) {
|
||||
h, _, reg, _ := newTestHandler(t)
|
||||
reg.Register("a1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Post("/api/v1/agents/{agentID}/compose", h.ComposeAction)
|
||||
|
||||
bodyBytes, _ := json.Marshal(map[string]string{"path": "/opt/stack", "action": "up"})
|
||||
req := newCancelledRequest(http.MethodPost, "/api/v1/agents/a1/compose", bytes.NewReader(bodyBytes))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusGatewayTimeout && w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 504 or 404, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@ -185,6 +186,7 @@ func (h *Handler) ListImages(w http.ResponseWriter, r *http.Request) {
|
||||
AgentID string `json:"agent_id"`
|
||||
Hostname string `json:"hostname"`
|
||||
Alias string `json:"alias"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
ID string `json:"id"`
|
||||
Tags []string `json:"tags"`
|
||||
Size int64 `json:"size"`
|
||||
@ -197,8 +199,9 @@ func (h *Handler) ListImages(w http.ResponseWriter, r *http.Request) {
|
||||
AgentID: agent.ID,
|
||||
Hostname: agent.Hostname,
|
||||
Alias: agent.Alias,
|
||||
IPAddress: agent.IPAddress,
|
||||
ID: img.GetId(),
|
||||
Tags: img.GetTags(),
|
||||
Tags: func() []string { if t := img.GetTags(); t != nil { return t }; return []string{} }(),
|
||||
Size: img.GetSize(),
|
||||
CreatedAt: img.GetCreatedAt(),
|
||||
})
|
||||
@ -214,6 +217,7 @@ func (h *Handler) ListVolumes(w http.ResponseWriter, r *http.Request) {
|
||||
AgentID string `json:"agent_id"`
|
||||
Hostname string `json:"hostname"`
|
||||
Alias string `json:"alias"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
Name string `json:"name"`
|
||||
Driver string `json:"driver"`
|
||||
Mountpoint string `json:"mountpoint"`
|
||||
@ -225,6 +229,7 @@ func (h *Handler) ListVolumes(w http.ResponseWriter, r *http.Request) {
|
||||
AgentID: agent.ID,
|
||||
Hostname: agent.Hostname,
|
||||
Alias: agent.Alias,
|
||||
IPAddress: agent.IPAddress,
|
||||
Name: vol.GetName(),
|
||||
Driver: vol.GetDriver(),
|
||||
Mountpoint: vol.GetMountpoint(),
|
||||
@ -238,25 +243,27 @@ func (h *Handler) ListVolumes(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func (h *Handler) ListNetworks(w http.ResponseWriter, r *http.Request) {
|
||||
type networkDTO struct {
|
||||
AgentID string `json:"agent_id"`
|
||||
Hostname string `json:"hostname"`
|
||||
Alias string `json:"alias"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Driver string `json:"driver"`
|
||||
Scope string `json:"scope"`
|
||||
AgentID string `json:"agent_id"`
|
||||
Hostname string `json:"hostname"`
|
||||
Alias string `json:"alias"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Driver string `json:"driver"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
var out []networkDTO
|
||||
for _, agent := range h.registry.List() {
|
||||
for _, net := range agent.Networks {
|
||||
out = append(out, networkDTO{
|
||||
AgentID: agent.ID,
|
||||
Hostname: agent.Hostname,
|
||||
Alias: agent.Alias,
|
||||
ID: net.GetId(),
|
||||
Name: net.GetName(),
|
||||
Driver: net.GetDriver(),
|
||||
Scope: net.GetScope(),
|
||||
AgentID: agent.ID,
|
||||
Hostname: agent.Hostname,
|
||||
Alias: agent.Alias,
|
||||
IPAddress: agent.IPAddress,
|
||||
ID: net.GetId(),
|
||||
Name: net.GetName(),
|
||||
Driver: net.GetDriver(),
|
||||
Scope: net.GetScope(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -390,6 +397,292 @@ func (h *Handler) EventsWS(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// ── File system & Compose ─────────────────────────────────────────────────────
|
||||
|
||||
// sendFileCmd sends a file/compose command to an agent and waits for the response.
|
||||
// It uses the request context with an added 30s deadline so the handler can be
|
||||
// tested by cancelling the context.
|
||||
func (h *Handler) sendFileCmd(r *http.Request, agentID string, msg *agentv1.ServerMessage, cmdID string) (*agentv1.FileResult, error) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
|
||||
defer cancel()
|
||||
return h.registry.SendAndWaitCtx(ctx, agentID, msg, cmdID)
|
||||
}
|
||||
|
||||
// FsList handles GET /api/v1/agents/{agentID}/fs/list?path=/some/dir
|
||||
func (h *Handler) FsList(w http.ResponseWriter, r *http.Request) {
|
||||
agentID := chi.URLParam(r, "agentID")
|
||||
path := r.URL.Query().Get("path")
|
||||
if path == "" {
|
||||
http.Error(w, "path required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
cmdID := uuid.NewString()
|
||||
result, err := h.sendFileCmd(r, agentID, &agentv1.ServerMessage{
|
||||
Payload: &agentv1.ServerMessage_ListDir{
|
||||
ListDir: &agentv1.ListDirCommand{
|
||||
CommandId: cmdID,
|
||||
Path: path,
|
||||
},
|
||||
},
|
||||
}, cmdID)
|
||||
if err != nil {
|
||||
if err.Error() == "agent not connected" {
|
||||
http.Error(w, "agent not connected", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
http.Error(w, "timeout waiting for agent", http.StatusGatewayTimeout)
|
||||
return
|
||||
}
|
||||
if !result.Success {
|
||||
http.Error(w, result.Error, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Content is JSON-encoded list of entries from the agent
|
||||
var entries json.RawMessage = result.Content
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(entries)
|
||||
}
|
||||
|
||||
// FsRead handles GET /api/v1/agents/{agentID}/fs/read?path=/some/file
|
||||
func (h *Handler) FsRead(w http.ResponseWriter, r *http.Request) {
|
||||
agentID := chi.URLParam(r, "agentID")
|
||||
path := r.URL.Query().Get("path")
|
||||
if path == "" {
|
||||
http.Error(w, "path required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
cmdID := uuid.NewString()
|
||||
result, err := h.sendFileCmd(r, agentID, &agentv1.ServerMessage{
|
||||
Payload: &agentv1.ServerMessage_ReadFile{
|
||||
ReadFile: &agentv1.ReadFileCommand{
|
||||
CommandId: cmdID,
|
||||
Path: path,
|
||||
},
|
||||
},
|
||||
}, cmdID)
|
||||
if err != nil {
|
||||
if err.Error() == "agent not connected" {
|
||||
http.Error(w, "agent not connected", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
http.Error(w, "timeout waiting for agent", http.StatusGatewayTimeout)
|
||||
return
|
||||
}
|
||||
if !result.Success {
|
||||
http.Error(w, result.Error, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
jsonOK(w, map[string]string{"content": string(result.Content)})
|
||||
}
|
||||
|
||||
// FsWrite handles POST /api/v1/agents/{agentID}/fs/write
|
||||
func (h *Handler) FsWrite(w http.ResponseWriter, r *http.Request) {
|
||||
agentID := chi.URLParam(r, "agentID")
|
||||
|
||||
var body struct {
|
||||
Path string `json:"path"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil || body.Path == "" {
|
||||
http.Error(w, "path and content required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
cmdID := uuid.NewString()
|
||||
result, err := h.sendFileCmd(r, agentID, &agentv1.ServerMessage{
|
||||
Payload: &agentv1.ServerMessage_WriteFile{
|
||||
WriteFile: &agentv1.WriteFileCommand{
|
||||
CommandId: cmdID,
|
||||
Path: body.Path,
|
||||
Content: []byte(body.Content),
|
||||
},
|
||||
},
|
||||
}, cmdID)
|
||||
if err != nil {
|
||||
if err.Error() == "agent not connected" {
|
||||
http.Error(w, "agent not connected", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
http.Error(w, "timeout waiting for agent", http.StatusGatewayTimeout)
|
||||
return
|
||||
}
|
||||
if !result.Success {
|
||||
http.Error(w, result.Error, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
jsonOK(w, map[string]bool{"ok": true})
|
||||
}
|
||||
|
||||
// FsMkdir handles POST /api/v1/agents/{agentID}/fs/mkdir
|
||||
func (h *Handler) FsMkdir(w http.ResponseWriter, r *http.Request) {
|
||||
agentID := chi.URLParam(r, "agentID")
|
||||
|
||||
var body struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil || body.Path == "" {
|
||||
http.Error(w, "path required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
cmdID := uuid.NewString()
|
||||
result, err := h.sendFileCmd(r, agentID, &agentv1.ServerMessage{
|
||||
Payload: &agentv1.ServerMessage_CreateDir{
|
||||
CreateDir: &agentv1.CreateDirCommand{
|
||||
CommandId: cmdID,
|
||||
Path: body.Path,
|
||||
},
|
||||
},
|
||||
}, cmdID)
|
||||
if err != nil {
|
||||
if err.Error() == "agent not connected" {
|
||||
http.Error(w, "agent not connected", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
http.Error(w, "timeout waiting for agent", http.StatusGatewayTimeout)
|
||||
return
|
||||
}
|
||||
if !result.Success {
|
||||
http.Error(w, result.Error, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
jsonOK(w, map[string]bool{"ok": true})
|
||||
}
|
||||
|
||||
// ComposeAction handles POST /api/v1/agents/{agentID}/compose
|
||||
func (h *Handler) ComposeAction(w http.ResponseWriter, r *http.Request) {
|
||||
agentID := chi.URLParam(r, "agentID")
|
||||
|
||||
var body struct {
|
||||
Path string `json:"path"`
|
||||
Action string `json:"action"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil || body.Path == "" || body.Action == "" {
|
||||
http.Error(w, "path and action required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
validActions := map[string]bool{"up": true, "down": true, "pull": true}
|
||||
if !validActions[body.Action] {
|
||||
http.Error(w, "action must be one of: up, down, pull", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
cmdID := uuid.NewString()
|
||||
result, err := h.sendFileCmd(r, agentID, &agentv1.ServerMessage{
|
||||
Payload: &agentv1.ServerMessage_ExecCompose{
|
||||
ExecCompose: &agentv1.ExecComposeCommand{
|
||||
CommandId: cmdID,
|
||||
Path: body.Path,
|
||||
Action: body.Action,
|
||||
},
|
||||
},
|
||||
}, cmdID)
|
||||
if err != nil {
|
||||
if err.Error() == "agent not connected" {
|
||||
http.Error(w, "agent not connected", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
http.Error(w, "timeout waiting for agent", http.StatusGatewayTimeout)
|
||||
return
|
||||
}
|
||||
if !result.Success {
|
||||
jsonErr, _ := json.Marshal(map[string]string{"error": result.Error, "output": string(result.Content)})
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write(jsonErr)
|
||||
return
|
||||
}
|
||||
|
||||
jsonOK(w, map[string]any{"ok": true, "output": string(result.Content)})
|
||||
}
|
||||
|
||||
// ── Auto-update policies ──────────────────────────────────────────────────────
|
||||
|
||||
// GetAutoUpdatePolicy handles GET /api/v1/agents/{agentID}/containers/{containerID}/auto-update
|
||||
func (h *Handler) GetAutoUpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
agentID := chi.URLParam(r, "agentID")
|
||||
containerID := chi.URLParam(r, "containerID")
|
||||
|
||||
p, err := h.store.GetAutoUpdatePolicy(agentID, containerID)
|
||||
if err != nil {
|
||||
http.Error(w, "store error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if p == nil {
|
||||
jsonOK(w, map[string]any{"enabled": false, "interval_minutes": 1440})
|
||||
return
|
||||
}
|
||||
jsonOK(w, map[string]any{
|
||||
"enabled": p.Enabled,
|
||||
"interval_minutes": p.IntervalMinutes,
|
||||
"last_checked_at": p.LastCheckedAt,
|
||||
"last_updated_at": p.LastUpdatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
// PutAutoUpdatePolicy handles PUT /api/v1/agents/{agentID}/containers/{containerID}/auto-update
|
||||
func (h *Handler) PutAutoUpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
agentID := chi.URLParam(r, "agentID")
|
||||
containerID := chi.URLParam(r, "containerID")
|
||||
|
||||
var body struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
IntervalMinutes int `json:"interval_minutes"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
http.Error(w, "invalid body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if body.IntervalMinutes < 60 || body.IntervalMinutes > 43200 {
|
||||
http.Error(w, "interval_minutes must be between 60 and 43200", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
p := &store.AutoUpdatePolicy{
|
||||
AgentID: agentID,
|
||||
ContainerID: containerID,
|
||||
Enabled: body.Enabled,
|
||||
IntervalMinutes: body.IntervalMinutes,
|
||||
}
|
||||
if err := h.store.UpsertAutoUpdatePolicy(p); err != nil {
|
||||
http.Error(w, "store error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
jsonOK(w, map[string]any{
|
||||
"enabled": p.Enabled,
|
||||
"interval_minutes": p.IntervalMinutes,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateNow handles POST /api/v1/agents/{agentID}/containers/{containerID}/update-now
|
||||
func (h *Handler) UpdateNow(w http.ResponseWriter, r *http.Request) {
|
||||
agentID := chi.URLParam(r, "agentID")
|
||||
containerID := chi.URLParam(r, "containerID")
|
||||
|
||||
cmdID := uuid.NewString()
|
||||
sent := h.registry.Send(agentID, &agentv1.ServerMessage{
|
||||
Payload: &agentv1.ServerMessage_UpdateContainer{
|
||||
UpdateContainer: &agentv1.UpdateContainerCommand{
|
||||
CommandId: cmdID,
|
||||
ContainerId: containerID,
|
||||
},
|
||||
},
|
||||
})
|
||||
if !sent {
|
||||
http.Error(w, "agent not connected", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
h.registry.RegisterPendingUpdate(agentID, cmdID, containerID)
|
||||
jsonOK(w, map[string]string{"command_id": cmdID})
|
||||
}
|
||||
|
||||
func jsonOK(w http.ResponseWriter, v any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(v)
|
||||
|
||||
@ -52,6 +52,14 @@ func NewRouter(h *Handler) http.Handler {
|
||||
r.Post("/agents/{agentID}/containers/{containerID}/action", h.ContainerAction)
|
||||
r.Get("/agents/{agentID}/containers/{containerID}/logs", h.LogsWS)
|
||||
r.Get("/events", h.EventsWS)
|
||||
r.Get("/agents/{agentID}/fs/list", h.FsList)
|
||||
r.Get("/agents/{agentID}/fs/read", h.FsRead)
|
||||
r.Post("/agents/{agentID}/fs/write", h.FsWrite)
|
||||
r.Post("/agents/{agentID}/fs/mkdir", h.FsMkdir)
|
||||
r.Post("/agents/{agentID}/compose", h.ComposeAction)
|
||||
r.Get("/agents/{agentID}/containers/{containerID}/auto-update", h.GetAutoUpdatePolicy)
|
||||
r.Put("/agents/{agentID}/containers/{containerID}/auto-update", h.PutAutoUpdatePolicy)
|
||||
r.Post("/agents/{agentID}/containers/{containerID}/update-now", h.UpdateNow)
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/containarr/server/internal/broker"
|
||||
agentv1 "github.com/containarr/server/internal/proto/agentv1"
|
||||
@ -125,11 +126,21 @@ func (g *Gateway) Tunnel(stream agentv1.AgentGateway_TunnelServer) error {
|
||||
})
|
||||
|
||||
case *agentv1.AgentMessage_Result:
|
||||
res := p.Result
|
||||
g.broker.Publish(broker.Event{
|
||||
Type: "command.result",
|
||||
AgentID: agentID,
|
||||
Payload: p.Result,
|
||||
Payload: res,
|
||||
})
|
||||
if containerID, found := g.registry.ResolvePendingUpdate(agentID, res.CommandId); found {
|
||||
now := time.Now()
|
||||
_ = g.store.UpdateAutoUpdateChecked(agentID, containerID, now)
|
||||
if res.Success {
|
||||
_ = g.store.UpdateAutoUpdateDone(agentID, containerID, now)
|
||||
} else {
|
||||
slog.Warn("update container failed", "agent_id", agentID, "container_id", containerID, "error", res.Error)
|
||||
}
|
||||
}
|
||||
|
||||
case *agentv1.AgentMessage_LogChunk:
|
||||
g.broker.Publish(broker.Event{
|
||||
@ -137,6 +148,29 @@ func (g *Gateway) Tunnel(stream agentv1.AgentGateway_TunnelServer) error {
|
||||
AgentID: agentID,
|
||||
Payload: p.LogChunk,
|
||||
})
|
||||
|
||||
case *agentv1.AgentMessage_FileResult:
|
||||
g.registry.ResolvePending(agentID, p.FileResult.CommandId, p.FileResult)
|
||||
|
||||
case *agentv1.AgentMessage_UpdateCheckResult:
|
||||
res := p.UpdateCheckResult
|
||||
if res.Error != "" {
|
||||
slog.Warn("update check error", "agent_id", agentID, "container_id", res.ContainerId, "error", res.Error)
|
||||
}
|
||||
_ = g.store.UpdateAutoUpdateChecked(agentID, res.ContainerId, time.Now())
|
||||
if res.UpdateAvailable {
|
||||
cmdID := newCommandID()
|
||||
slog.Info("update available, triggering UpdateContainerCommand", "agent_id", agentID, "container_id", res.ContainerId, "command_id", cmdID)
|
||||
g.registry.Send(agentID, &agentv1.ServerMessage{
|
||||
Payload: &agentv1.ServerMessage_UpdateContainer{
|
||||
UpdateContainer: &agentv1.UpdateContainerCommand{
|
||||
CommandId: cmdID,
|
||||
ContainerId: res.ContainerId,
|
||||
},
|
||||
},
|
||||
})
|
||||
g.registry.RegisterPendingUpdate(agentID, cmdID, res.ContainerId)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -20,7 +22,10 @@ type AgentState struct {
|
||||
Volumes []*agentv1.VolumeInfo
|
||||
Networks []*agentv1.NetworkInfo
|
||||
|
||||
cmdCh chan *agentv1.ServerMessage
|
||||
cmdCh chan *agentv1.ServerMessage
|
||||
pendingFiles map[string]chan *agentv1.FileResult
|
||||
pendingUpdates map[string]string // commandID → containerID
|
||||
pendingMu sync.Mutex
|
||||
}
|
||||
|
||||
type Registry struct {
|
||||
@ -34,13 +39,15 @@ func NewRegistry() *Registry {
|
||||
|
||||
func (r *Registry) Register(id, hostname, alias, ipAddress, arch, os string) *AgentState {
|
||||
state := &AgentState{
|
||||
ID: id,
|
||||
Hostname: hostname,
|
||||
Alias: alias,
|
||||
IPAddress: ipAddress,
|
||||
Arch: arch,
|
||||
OS: os,
|
||||
cmdCh: make(chan *agentv1.ServerMessage, 16),
|
||||
ID: id,
|
||||
Hostname: hostname,
|
||||
Alias: alias,
|
||||
IPAddress: ipAddress,
|
||||
Arch: arch,
|
||||
OS: os,
|
||||
cmdCh: make(chan *agentv1.ServerMessage, 16),
|
||||
pendingFiles: make(map[string]chan *agentv1.FileResult),
|
||||
pendingUpdates: make(map[string]string),
|
||||
}
|
||||
r.mu.Lock()
|
||||
r.agents[id] = state
|
||||
@ -118,3 +125,113 @@ func (r *Registry) Send(agentID string, msg *agentv1.ServerMessage) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterPending registers a channel waiting for a FileResult with the given cmdID.
|
||||
func (r *Registry) RegisterPending(agentID, cmdID string) chan *agentv1.FileResult {
|
||||
r.mu.RLock()
|
||||
s, ok := r.agents[agentID]
|
||||
r.mu.RUnlock()
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
ch := make(chan *agentv1.FileResult, 1)
|
||||
s.pendingMu.Lock()
|
||||
s.pendingFiles[cmdID] = ch
|
||||
s.pendingMu.Unlock()
|
||||
return ch
|
||||
}
|
||||
|
||||
// ResolvePending sends the FileResult to the waiting channel identified by cmdID.
|
||||
func (r *Registry) ResolvePending(agentID, cmdID string, result *agentv1.FileResult) {
|
||||
r.mu.RLock()
|
||||
s, ok := r.agents[agentID]
|
||||
r.mu.RUnlock()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
s.pendingMu.Lock()
|
||||
ch, ok := s.pendingFiles[cmdID]
|
||||
if ok {
|
||||
delete(s.pendingFiles, cmdID)
|
||||
}
|
||||
s.pendingMu.Unlock()
|
||||
if ok {
|
||||
select {
|
||||
case ch <- result:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CancelPending removes the pending channel for cmdID (cleanup on timeout).
|
||||
func (r *Registry) CancelPending(agentID, cmdID string) {
|
||||
r.mu.RLock()
|
||||
s, ok := r.agents[agentID]
|
||||
r.mu.RUnlock()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
s.pendingMu.Lock()
|
||||
delete(s.pendingFiles, cmdID)
|
||||
s.pendingMu.Unlock()
|
||||
}
|
||||
|
||||
// SendAndWait registers a pending channel, sends msg to the agent, and waits up
|
||||
// to 30 seconds for the FileResult response identified by cmdID.
|
||||
func (r *Registry) SendAndWait(agentID string, msg *agentv1.ServerMessage, cmdID string) (*agentv1.FileResult, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
return r.SendAndWaitCtx(ctx, agentID, msg, cmdID)
|
||||
}
|
||||
|
||||
// RegisterPendingUpdate enregistre un commandID en attente de CommandResult pour un UpdateContainer.
|
||||
func (r *Registry) RegisterPendingUpdate(agentID, cmdID, containerID string) {
|
||||
r.mu.RLock()
|
||||
s, ok := r.agents[agentID]
|
||||
r.mu.RUnlock()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
s.pendingMu.Lock()
|
||||
s.pendingUpdates[cmdID] = containerID
|
||||
s.pendingMu.Unlock()
|
||||
}
|
||||
|
||||
// ResolvePendingUpdate retourne le containerID associé au commandID et le supprime de la map.
|
||||
// Retourne ("", false) si le commandID n'est pas connu.
|
||||
func (r *Registry) ResolvePendingUpdate(agentID, cmdID string) (string, bool) {
|
||||
r.mu.RLock()
|
||||
s, ok := r.agents[agentID]
|
||||
r.mu.RUnlock()
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
s.pendingMu.Lock()
|
||||
containerID, found := s.pendingUpdates[cmdID]
|
||||
if found {
|
||||
delete(s.pendingUpdates, cmdID)
|
||||
}
|
||||
s.pendingMu.Unlock()
|
||||
return containerID, found
|
||||
}
|
||||
|
||||
// SendAndWaitCtx is like SendAndWait but uses the provided context for timeout control.
|
||||
func (r *Registry) SendAndWaitCtx(ctx context.Context, agentID string, msg *agentv1.ServerMessage, cmdID string) (*agentv1.FileResult, error) {
|
||||
ch := r.RegisterPending(agentID, cmdID)
|
||||
if ch == nil {
|
||||
return nil, fmt.Errorf("agent not connected")
|
||||
}
|
||||
|
||||
if !r.Send(agentID, msg) {
|
||||
r.CancelPending(agentID, cmdID)
|
||||
return nil, fmt.Errorf("agent not connected")
|
||||
}
|
||||
|
||||
select {
|
||||
case result := <-ch:
|
||||
return result, nil
|
||||
case <-ctx.Done():
|
||||
r.CancelPending(agentID, cmdID)
|
||||
return nil, fmt.Errorf("timeout waiting for agent response")
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -153,3 +154,112 @@ func TestSend_FullChannel(t *testing.T) {
|
||||
t.Error("Send should return false when channel is full")
|
||||
}
|
||||
}
|
||||
|
||||
// ── Pending file correlations ──────────────────────────────────────────────────
|
||||
|
||||
func TestRegisterPending_UnknownAgent(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
ch := r.RegisterPending("ghost", "cmd1")
|
||||
if ch != nil {
|
||||
t.Error("expected nil channel for unknown agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvePending_Success(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("id1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
ch := r.RegisterPending("id1", "cmd1")
|
||||
if ch == nil {
|
||||
t.Fatal("expected non-nil channel")
|
||||
}
|
||||
|
||||
result := &agentv1.FileResult{CommandId: "cmd1", Success: true, Content: []byte("data")}
|
||||
r.ResolvePending("id1", "cmd1", result)
|
||||
|
||||
select {
|
||||
case got := <-ch:
|
||||
if got.CommandId != "cmd1" || !got.Success {
|
||||
t.Errorf("unexpected result: %+v", got)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for resolve")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvePending_UnknownAgent(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
// must not panic
|
||||
r.ResolvePending("ghost", "cmd1", &agentv1.FileResult{})
|
||||
}
|
||||
|
||||
func TestResolvePending_UnknownCmd(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("id1", "h", "a", "ip", "arch", "os")
|
||||
// must not panic
|
||||
r.ResolvePending("id1", "nonexistent", &agentv1.FileResult{})
|
||||
}
|
||||
|
||||
func TestCancelPending(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("id1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
r.RegisterPending("id1", "cmd1")
|
||||
r.CancelPending("id1", "cmd1")
|
||||
|
||||
// After cancel, resolving should be a no-op (not panic)
|
||||
r.ResolvePending("id1", "cmd1", &agentv1.FileResult{})
|
||||
}
|
||||
|
||||
func TestCancelPending_UnknownAgent(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
// must not panic
|
||||
r.CancelPending("ghost", "cmd1")
|
||||
}
|
||||
|
||||
func TestSendAndWaitCtx_AgentNotConnected(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
ctx := context.Background()
|
||||
_, err := r.SendAndWaitCtx(ctx, "ghost", &agentv1.ServerMessage{}, "cmd1")
|
||||
if err == nil || err.Error() != "agent not connected" {
|
||||
t.Errorf("expected 'agent not connected', got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendAndWaitCtx_Timeout(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("id1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
// Use an already-cancelled context to force immediate timeout.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // cancel immediately
|
||||
|
||||
_, err := r.SendAndWaitCtx(ctx, "id1", &agentv1.ServerMessage{}, "cmd-timeout")
|
||||
if err == nil {
|
||||
t.Error("expected timeout or not-connected error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendAndWaitCtx_Success(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("id1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
cmdID := "cmd-success"
|
||||
expected := &agentv1.FileResult{CommandId: cmdID, Success: true, Content: []byte("hello")}
|
||||
|
||||
// Simulate the agent responding after the send.
|
||||
go func() {
|
||||
// Wait briefly for RegisterPending + Send to happen.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
r.ResolvePending("id1", cmdID, expected)
|
||||
}()
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := r.SendAndWaitCtx(ctx, "id1", &agentv1.ServerMessage{}, cmdID)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.CommandId != cmdID || !result.Success {
|
||||
t.Errorf("unexpected result: %+v", result)
|
||||
}
|
||||
}
|
||||
|
||||
39
server/internal/scheduler/adapter.go
Normal file
39
server/internal/scheduler/adapter.go
Normal file
@ -0,0 +1,39 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/containarr/server/internal/store"
|
||||
)
|
||||
|
||||
// StoreAdapter wraps *store.Store so it satisfies StoreInterface.
|
||||
type StoreAdapter struct {
|
||||
s *store.Store
|
||||
}
|
||||
|
||||
// NewStoreAdapter creates a StoreAdapter wrapping the given *store.Store.
|
||||
func NewStoreAdapter(s *store.Store) *StoreAdapter {
|
||||
return &StoreAdapter{s: s}
|
||||
}
|
||||
|
||||
// ListDueAutoUpdatePolicies implements StoreInterface by converting
|
||||
// *store.AutoUpdatePolicy to DuePolicy.
|
||||
func (a *StoreAdapter) ListDueAutoUpdatePolicies(now time.Time) ([]DuePolicy, error) {
|
||||
policies, err := a.s.ListDueAutoUpdatePolicies(now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]DuePolicy, 0, len(policies))
|
||||
for _, p := range policies {
|
||||
out = append(out, DuePolicy{
|
||||
AgentID: p.AgentID,
|
||||
ContainerID: p.ContainerID,
|
||||
})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// UpdateAutoUpdateChecked implements StoreInterface.
|
||||
func (a *StoreAdapter) UpdateAutoUpdateChecked(agentID, containerID string, at time.Time) error {
|
||||
return a.s.UpdateAutoUpdateChecked(agentID, containerID, at)
|
||||
}
|
||||
86
server/internal/scheduler/scheduler.go
Normal file
86
server/internal/scheduler/scheduler.go
Normal file
@ -0,0 +1,86 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
agentv1 "github.com/containarr/server/internal/proto/agentv1"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// DuePolicy is the scheduler's minimal view of an auto-update policy that is
// due for a check: it identifies the container and the agent that runs it.
type DuePolicy struct {
	AgentID, ContainerID string
}
|
||||
|
||||
// StoreInterface defines the minimal store methods used by the scheduler.
|
||||
// Implementations must convert their internal policy type to DuePolicy when
|
||||
// implementing ListDueAutoUpdatePolicies, or use StoreAdapter provided below.
|
||||
type StoreInterface interface {
|
||||
ListDueAutoUpdatePolicies(now time.Time) ([]DuePolicy, error)
|
||||
UpdateAutoUpdateChecked(agentID, containerID string, at time.Time) error
|
||||
}
|
||||
|
||||
// RegistryInterface defines the minimal registry methods used by the scheduler.
|
||||
type RegistryInterface interface {
|
||||
Send(agentID string, msg *agentv1.ServerMessage) bool
|
||||
}
|
||||
|
||||
// Scheduler sends CheckUpdateCommand to agents every 60 seconds for containers
|
||||
// with an active and due auto-update policy.
|
||||
type Scheduler struct {
|
||||
store StoreInterface
|
||||
registry RegistryInterface
|
||||
}
|
||||
|
||||
// New creates a new Scheduler.
|
||||
func New(store StoreInterface, registry RegistryInterface) *Scheduler {
|
||||
return &Scheduler{store: store, registry: registry}
|
||||
}
|
||||
|
||||
// Start runs the scheduler loop until ctx is cancelled.
|
||||
func (s *Scheduler) Start(ctx context.Context) {
|
||||
ticker := time.NewTicker(60 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
slog.Info("scheduler stopped")
|
||||
return
|
||||
case t := <-ticker.C:
|
||||
s.tick(t)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Scheduler) tick(now time.Time) {
|
||||
policies, err := s.store.ListDueAutoUpdatePolicies(now)
|
||||
if err != nil {
|
||||
slog.Error("scheduler: list due policies", "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, p := range policies {
|
||||
cmdID := uuid.NewString()
|
||||
msg := &agentv1.ServerMessage{
|
||||
Payload: &agentv1.ServerMessage_CheckUpdate{
|
||||
CheckUpdate: &agentv1.CheckUpdateCommand{
|
||||
CommandId: cmdID,
|
||||
ContainerId: p.ContainerID,
|
||||
},
|
||||
},
|
||||
}
|
||||
sent := s.registry.Send(p.AgentID, msg)
|
||||
if !sent {
|
||||
slog.Debug("scheduler: agent not connected, skipping", "agent_id", p.AgentID, "container_id", p.ContainerID)
|
||||
continue
|
||||
}
|
||||
if err := s.store.UpdateAutoUpdateChecked(p.AgentID, p.ContainerID, now); err != nil {
|
||||
slog.Error("scheduler: update last_checked_at", "agent_id", p.AgentID, "container_id", p.ContainerID, "err", err)
|
||||
}
|
||||
slog.Info("scheduler: sent CheckUpdateCommand", "agent_id", p.AgentID, "container_id", p.ContainerID, "command_id", cmdID)
|
||||
}
|
||||
}
|
||||
@ -49,6 +49,16 @@ func (s *Store) migrate() error {
|
||||
last_seen_at DATETIME,
|
||||
online INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS auto_update_policies (
|
||||
agent_id TEXT NOT NULL,
|
||||
container_id TEXT NOT NULL,
|
||||
enabled INTEGER NOT NULL DEFAULT 1,
|
||||
interval_minutes INTEGER NOT NULL DEFAULT 1440,
|
||||
last_checked_at DATETIME,
|
||||
last_updated_at DATETIME,
|
||||
PRIMARY KEY (agent_id, container_id),
|
||||
FOREIGN KEY (agent_id) REFERENCES agents(id) ON DELETE CASCADE
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -186,3 +196,106 @@ func boolToInt(b bool) int {
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// ── AutoUpdatePolicies ────────────────────────────────────────────────────────
|
||||
|
||||
// AutoUpdatePolicy mirrors one row of the auto_update_policies table, keyed
// by (AgentID, ContainerID).
type AutoUpdatePolicy struct {
	AgentID, ContainerID string
	// Enabled is stored as an integer column; only enabled policies are
	// returned by ListDueAutoUpdatePolicies.
	Enabled bool
	// IntervalMinutes is the minimum number of minutes between two checks.
	IntervalMinutes int
	// LastCheckedAt and LastUpdatedAt are nil when the corresponding column
	// is NULL.
	LastCheckedAt, LastUpdatedAt *time.Time
}
|
||||
|
||||
func (s *Store) UpsertAutoUpdatePolicy(p *AutoUpdatePolicy) error {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT OR REPLACE INTO auto_update_policies
|
||||
(agent_id, container_id, enabled, interval_minutes, last_checked_at, last_updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
`, p.AgentID, p.ContainerID, boolToInt(p.Enabled), p.IntervalMinutes, p.LastCheckedAt, p.LastUpdatedAt)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) GetAutoUpdatePolicy(agentID, containerID string) (*AutoUpdatePolicy, error) {
|
||||
row := s.db.QueryRow(`
|
||||
SELECT agent_id, container_id, enabled, interval_minutes, last_checked_at, last_updated_at
|
||||
FROM auto_update_policies WHERE agent_id = ? AND container_id = ?
|
||||
`, agentID, containerID)
|
||||
p := &AutoUpdatePolicy{}
|
||||
var enabled int
|
||||
var lastChecked, lastUpdated sql.NullTime
|
||||
err := row.Scan(&p.AgentID, &p.ContainerID, &enabled, &p.IntervalMinutes, &lastChecked, &lastUpdated)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.Enabled = enabled == 1
|
||||
if lastChecked.Valid {
|
||||
t := lastChecked.Time
|
||||
p.LastCheckedAt = &t
|
||||
}
|
||||
if lastUpdated.Valid {
|
||||
t := lastUpdated.Time
|
||||
p.LastUpdatedAt = &t
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (s *Store) ListDueAutoUpdatePolicies(now time.Time) ([]*AutoUpdatePolicy, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT agent_id, container_id, enabled, interval_minutes, last_checked_at, last_updated_at
|
||||
FROM auto_update_policies
|
||||
WHERE enabled = 1
|
||||
AND (last_checked_at IS NULL
|
||||
OR (julianday(?) - julianday(last_checked_at)) * 1440 >= interval_minutes)
|
||||
`, now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var policies []*AutoUpdatePolicy
|
||||
for rows.Next() {
|
||||
p := &AutoUpdatePolicy{}
|
||||
var enabled int
|
||||
var lastChecked, lastUpdated sql.NullTime
|
||||
if err := rows.Scan(&p.AgentID, &p.ContainerID, &enabled, &p.IntervalMinutes, &lastChecked, &lastUpdated); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.Enabled = enabled == 1
|
||||
if lastChecked.Valid {
|
||||
t := lastChecked.Time
|
||||
p.LastCheckedAt = &t
|
||||
}
|
||||
if lastUpdated.Valid {
|
||||
t := lastUpdated.Time
|
||||
p.LastUpdatedAt = &t
|
||||
}
|
||||
policies = append(policies, p)
|
||||
}
|
||||
return policies, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) UpdateAutoUpdateChecked(agentID, containerID string, at time.Time) error {
|
||||
_, err := s.db.Exec(`
|
||||
UPDATE auto_update_policies SET last_checked_at = ? WHERE agent_id = ? AND container_id = ?
|
||||
`, at, agentID, containerID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) UpdateAutoUpdateDone(agentID, containerID string, at time.Time) error {
|
||||
_, err := s.db.Exec(`
|
||||
UPDATE auto_update_policies SET last_updated_at = ? WHERE agent_id = ? AND container_id = ?
|
||||
`, at, agentID, containerID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) DeleteAutoUpdatePolicy(agentID, containerID string) error {
|
||||
_, err := s.db.Exec(`
|
||||
DELETE FROM auto_update_policies WHERE agent_id = ? AND container_id = ?
|
||||
`, agentID, containerID)
|
||||
return err
|
||||
}
|
||||
|
||||
@ -2,6 +2,7 @@ package store
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func newTestStore(t *testing.T) *Store {
|
||||
@ -254,3 +255,199 @@ func TestCreateAgentToken_IdempotentIgnore(t *testing.T) {
|
||||
t.Fatalf("second call (should be idempotent): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ── AutoUpdatePolicies ────────────────────────────────────────────────────────
|
||||
|
||||
// helper: create an agent prerequisite for FK constraints.
|
||||
func createAgent(t *testing.T, s *Store, id, token, hostname string) {
|
||||
t.Helper()
|
||||
if err := s.CreateAgentToken(id, token, hostname); err != nil {
|
||||
t.Fatalf("createAgent: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertAndGetAutoUpdatePolicy(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
createAgent(t, s, "ag1", "tok1", "host1")
|
||||
|
||||
p := &AutoUpdatePolicy{
|
||||
AgentID: "ag1",
|
||||
ContainerID: "ctr1",
|
||||
Enabled: true,
|
||||
IntervalMinutes: 60,
|
||||
}
|
||||
if err := s.UpsertAutoUpdatePolicy(p); err != nil {
|
||||
t.Fatalf("UpsertAutoUpdatePolicy: %v", err)
|
||||
}
|
||||
|
||||
got, err := s.GetAutoUpdatePolicy("ag1", "ctr1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetAutoUpdatePolicy: %v", err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatal("expected policy, got nil")
|
||||
}
|
||||
if !got.Enabled || got.IntervalMinutes != 60 {
|
||||
t.Errorf("unexpected policy: %+v", got)
|
||||
}
|
||||
if got.LastCheckedAt != nil || got.LastUpdatedAt != nil {
|
||||
t.Error("expected nil timestamps on fresh policy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAutoUpdatePolicy_NotFound(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
|
||||
p, err := s.GetAutoUpdatePolicy("nobody", "ctr")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if p != nil {
|
||||
t.Errorf("expected nil, got %+v", p)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertAutoUpdatePolicy_Update(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
createAgent(t, s, "ag1", "tok1", "host1")
|
||||
|
||||
_ = s.UpsertAutoUpdatePolicy(&AutoUpdatePolicy{AgentID: "ag1", ContainerID: "ctr1", Enabled: true, IntervalMinutes: 60})
|
||||
_ = s.UpsertAutoUpdatePolicy(&AutoUpdatePolicy{AgentID: "ag1", ContainerID: "ctr1", Enabled: false, IntervalMinutes: 1440})
|
||||
|
||||
got, err := s.GetAutoUpdatePolicy("ag1", "ctr1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetAutoUpdatePolicy: %v", err)
|
||||
}
|
||||
if got.Enabled || got.IntervalMinutes != 1440 {
|
||||
t.Errorf("expected updated policy, got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAutoUpdateChecked(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
createAgent(t, s, "ag1", "tok1", "host1")
|
||||
_ = s.UpsertAutoUpdatePolicy(&AutoUpdatePolicy{AgentID: "ag1", ContainerID: "ctr1", Enabled: true, IntervalMinutes: 60})
|
||||
|
||||
now := time.Now().Truncate(time.Second)
|
||||
if err := s.UpdateAutoUpdateChecked("ag1", "ctr1", now); err != nil {
|
||||
t.Fatalf("UpdateAutoUpdateChecked: %v", err)
|
||||
}
|
||||
|
||||
got, _ := s.GetAutoUpdatePolicy("ag1", "ctr1")
|
||||
if got.LastCheckedAt == nil {
|
||||
t.Fatal("expected LastCheckedAt to be set")
|
||||
}
|
||||
if got.LastCheckedAt.UTC().Truncate(time.Second) != now.UTC() {
|
||||
t.Errorf("expected %v, got %v", now.UTC(), got.LastCheckedAt.UTC().Truncate(time.Second))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAutoUpdateDone(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
createAgent(t, s, "ag1", "tok1", "host1")
|
||||
_ = s.UpsertAutoUpdatePolicy(&AutoUpdatePolicy{AgentID: "ag1", ContainerID: "ctr1", Enabled: true, IntervalMinutes: 60})
|
||||
|
||||
now := time.Now().Truncate(time.Second)
|
||||
if err := s.UpdateAutoUpdateDone("ag1", "ctr1", now); err != nil {
|
||||
t.Fatalf("UpdateAutoUpdateDone: %v", err)
|
||||
}
|
||||
|
||||
got, _ := s.GetAutoUpdatePolicy("ag1", "ctr1")
|
||||
if got.LastUpdatedAt == nil {
|
||||
t.Fatal("expected LastUpdatedAt to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDueAutoUpdatePolicies_NullLastChecked(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
createAgent(t, s, "ag1", "tok1", "host1")
|
||||
_ = s.UpsertAutoUpdatePolicy(&AutoUpdatePolicy{AgentID: "ag1", ContainerID: "ctr1", Enabled: true, IntervalMinutes: 60})
|
||||
|
||||
// last_checked_at IS NULL → should be due immediately.
|
||||
due, err := s.ListDueAutoUpdatePolicies(time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("ListDueAutoUpdatePolicies: %v", err)
|
||||
}
|
||||
if len(due) != 1 {
|
||||
t.Fatalf("expected 1 due policy, got %d", len(due))
|
||||
}
|
||||
if due[0].ContainerID != "ctr1" {
|
||||
t.Errorf("unexpected container: %q", due[0].ContainerID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDueAutoUpdatePolicies_NotDueYet(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
createAgent(t, s, "ag1", "tok1", "host1")
|
||||
_ = s.UpsertAutoUpdatePolicy(&AutoUpdatePolicy{AgentID: "ag1", ContainerID: "ctr1", Enabled: true, IntervalMinutes: 1440})
|
||||
|
||||
// Mark as just checked — not due yet.
|
||||
_ = s.UpdateAutoUpdateChecked("ag1", "ctr1", time.Now())
|
||||
|
||||
due, err := s.ListDueAutoUpdatePolicies(time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("ListDueAutoUpdatePolicies: %v", err)
|
||||
}
|
||||
if len(due) != 0 {
|
||||
t.Fatalf("expected 0 due policies (just checked), got %d", len(due))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDueAutoUpdatePolicies_Due(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
createAgent(t, s, "ag1", "tok1", "host1")
|
||||
_ = s.UpsertAutoUpdatePolicy(&AutoUpdatePolicy{AgentID: "ag1", ContainerID: "ctr1", Enabled: true, IntervalMinutes: 60})
|
||||
|
||||
// Simulate last check 2 hours ago → should be due.
|
||||
past := time.Now().Add(-2 * time.Hour)
|
||||
_ = s.UpdateAutoUpdateChecked("ag1", "ctr1", past)
|
||||
|
||||
due, err := s.ListDueAutoUpdatePolicies(time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("ListDueAutoUpdatePolicies: %v", err)
|
||||
}
|
||||
if len(due) != 1 {
|
||||
t.Fatalf("expected 1 due policy (overdue), got %d", len(due))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDueAutoUpdatePolicies_DisabledExcluded(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
createAgent(t, s, "ag1", "tok1", "host1")
|
||||
_ = s.UpsertAutoUpdatePolicy(&AutoUpdatePolicy{AgentID: "ag1", ContainerID: "ctr1", Enabled: false, IntervalMinutes: 60})
|
||||
|
||||
due, err := s.ListDueAutoUpdatePolicies(time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("ListDueAutoUpdatePolicies: %v", err)
|
||||
}
|
||||
if len(due) != 0 {
|
||||
t.Fatalf("expected 0 due policies (disabled), got %d", len(due))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAutoUpdatePolicy(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
createAgent(t, s, "ag1", "tok1", "host1")
|
||||
_ = s.UpsertAutoUpdatePolicy(&AutoUpdatePolicy{AgentID: "ag1", ContainerID: "ctr1", Enabled: true, IntervalMinutes: 60})
|
||||
|
||||
if err := s.DeleteAutoUpdatePolicy("ag1", "ctr1"); err != nil {
|
||||
t.Fatalf("DeleteAutoUpdatePolicy: %v", err)
|
||||
}
|
||||
|
||||
got, err := s.GetAutoUpdatePolicy("ag1", "ctr1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetAutoUpdatePolicy: %v", err)
|
||||
}
|
||||
if got != nil {
|
||||
t.Error("expected nil after deletion")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAutoUpdatePolicy_Idempotent(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
|
||||
// Deleting a non-existent policy should not error.
|
||||
if err := s.DeleteAutoUpdatePolicy("nobody", "ctr"); err != nil {
|
||||
t.Fatalf("DeleteAutoUpdatePolicy on missing: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user