feat: add auto update
This commit is contained in:
@ -4,6 +4,7 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/containarr/server/internal/broker"
|
||||
agentv1 "github.com/containarr/server/internal/proto/agentv1"
|
||||
@ -125,11 +126,21 @@ func (g *Gateway) Tunnel(stream agentv1.AgentGateway_TunnelServer) error {
|
||||
})
|
||||
|
||||
case *agentv1.AgentMessage_Result:
|
||||
res := p.Result
|
||||
g.broker.Publish(broker.Event{
|
||||
Type: "command.result",
|
||||
AgentID: agentID,
|
||||
Payload: p.Result,
|
||||
Payload: res,
|
||||
})
|
||||
if containerID, found := g.registry.ResolvePendingUpdate(agentID, res.CommandId); found {
|
||||
now := time.Now()
|
||||
_ = g.store.UpdateAutoUpdateChecked(agentID, containerID, now)
|
||||
if res.Success {
|
||||
_ = g.store.UpdateAutoUpdateDone(agentID, containerID, now)
|
||||
} else {
|
||||
slog.Warn("update container failed", "agent_id", agentID, "container_id", containerID, "error", res.Error)
|
||||
}
|
||||
}
|
||||
|
||||
case *agentv1.AgentMessage_LogChunk:
|
||||
g.broker.Publish(broker.Event{
|
||||
@ -137,6 +148,29 @@ func (g *Gateway) Tunnel(stream agentv1.AgentGateway_TunnelServer) error {
|
||||
AgentID: agentID,
|
||||
Payload: p.LogChunk,
|
||||
})
|
||||
|
||||
case *agentv1.AgentMessage_FileResult:
|
||||
g.registry.ResolvePending(agentID, p.FileResult.CommandId, p.FileResult)
|
||||
|
||||
case *agentv1.AgentMessage_UpdateCheckResult:
|
||||
res := p.UpdateCheckResult
|
||||
if res.Error != "" {
|
||||
slog.Warn("update check error", "agent_id", agentID, "container_id", res.ContainerId, "error", res.Error)
|
||||
}
|
||||
_ = g.store.UpdateAutoUpdateChecked(agentID, res.ContainerId, time.Now())
|
||||
if res.UpdateAvailable {
|
||||
cmdID := newCommandID()
|
||||
slog.Info("update available, triggering UpdateContainerCommand", "agent_id", agentID, "container_id", res.ContainerId, "command_id", cmdID)
|
||||
g.registry.Send(agentID, &agentv1.ServerMessage{
|
||||
Payload: &agentv1.ServerMessage_UpdateContainer{
|
||||
UpdateContainer: &agentv1.UpdateContainerCommand{
|
||||
CommandId: cmdID,
|
||||
ContainerId: res.ContainerId,
|
||||
},
|
||||
},
|
||||
})
|
||||
g.registry.RegisterPendingUpdate(agentID, cmdID, res.ContainerId)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -20,7 +22,10 @@ type AgentState struct {
|
||||
Volumes []*agentv1.VolumeInfo
|
||||
Networks []*agentv1.NetworkInfo
|
||||
|
||||
cmdCh chan *agentv1.ServerMessage
|
||||
cmdCh chan *agentv1.ServerMessage
|
||||
pendingFiles map[string]chan *agentv1.FileResult
|
||||
pendingUpdates map[string]string // commandID → containerID
|
||||
pendingMu sync.Mutex
|
||||
}
|
||||
|
||||
type Registry struct {
|
||||
@ -34,13 +39,15 @@ func NewRegistry() *Registry {
|
||||
|
||||
func (r *Registry) Register(id, hostname, alias, ipAddress, arch, os string) *AgentState {
|
||||
state := &AgentState{
|
||||
ID: id,
|
||||
Hostname: hostname,
|
||||
Alias: alias,
|
||||
IPAddress: ipAddress,
|
||||
Arch: arch,
|
||||
OS: os,
|
||||
cmdCh: make(chan *agentv1.ServerMessage, 16),
|
||||
ID: id,
|
||||
Hostname: hostname,
|
||||
Alias: alias,
|
||||
IPAddress: ipAddress,
|
||||
Arch: arch,
|
||||
OS: os,
|
||||
cmdCh: make(chan *agentv1.ServerMessage, 16),
|
||||
pendingFiles: make(map[string]chan *agentv1.FileResult),
|
||||
pendingUpdates: make(map[string]string),
|
||||
}
|
||||
r.mu.Lock()
|
||||
r.agents[id] = state
|
||||
@ -118,3 +125,113 @@ func (r *Registry) Send(agentID string, msg *agentv1.ServerMessage) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterPending registers a channel waiting for a FileResult with the given cmdID.
|
||||
func (r *Registry) RegisterPending(agentID, cmdID string) chan *agentv1.FileResult {
|
||||
r.mu.RLock()
|
||||
s, ok := r.agents[agentID]
|
||||
r.mu.RUnlock()
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
ch := make(chan *agentv1.FileResult, 1)
|
||||
s.pendingMu.Lock()
|
||||
s.pendingFiles[cmdID] = ch
|
||||
s.pendingMu.Unlock()
|
||||
return ch
|
||||
}
|
||||
|
||||
// ResolvePending sends the FileResult to the waiting channel identified by cmdID.
|
||||
func (r *Registry) ResolvePending(agentID, cmdID string, result *agentv1.FileResult) {
|
||||
r.mu.RLock()
|
||||
s, ok := r.agents[agentID]
|
||||
r.mu.RUnlock()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
s.pendingMu.Lock()
|
||||
ch, ok := s.pendingFiles[cmdID]
|
||||
if ok {
|
||||
delete(s.pendingFiles, cmdID)
|
||||
}
|
||||
s.pendingMu.Unlock()
|
||||
if ok {
|
||||
select {
|
||||
case ch <- result:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CancelPending removes the pending channel for cmdID (cleanup on timeout).
|
||||
func (r *Registry) CancelPending(agentID, cmdID string) {
|
||||
r.mu.RLock()
|
||||
s, ok := r.agents[agentID]
|
||||
r.mu.RUnlock()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
s.pendingMu.Lock()
|
||||
delete(s.pendingFiles, cmdID)
|
||||
s.pendingMu.Unlock()
|
||||
}
|
||||
|
||||
// SendAndWait registers a pending channel, sends msg to the agent, and waits up
|
||||
// to 30 seconds for the FileResult response identified by cmdID.
|
||||
func (r *Registry) SendAndWait(agentID string, msg *agentv1.ServerMessage, cmdID string) (*agentv1.FileResult, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
return r.SendAndWaitCtx(ctx, agentID, msg, cmdID)
|
||||
}
|
||||
|
||||
// RegisterPendingUpdate enregistre un commandID en attente de CommandResult pour un UpdateContainer.
|
||||
func (r *Registry) RegisterPendingUpdate(agentID, cmdID, containerID string) {
|
||||
r.mu.RLock()
|
||||
s, ok := r.agents[agentID]
|
||||
r.mu.RUnlock()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
s.pendingMu.Lock()
|
||||
s.pendingUpdates[cmdID] = containerID
|
||||
s.pendingMu.Unlock()
|
||||
}
|
||||
|
||||
// ResolvePendingUpdate retourne le containerID associé au commandID et le supprime de la map.
|
||||
// Retourne ("", false) si le commandID n'est pas connu.
|
||||
func (r *Registry) ResolvePendingUpdate(agentID, cmdID string) (string, bool) {
|
||||
r.mu.RLock()
|
||||
s, ok := r.agents[agentID]
|
||||
r.mu.RUnlock()
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
s.pendingMu.Lock()
|
||||
containerID, found := s.pendingUpdates[cmdID]
|
||||
if found {
|
||||
delete(s.pendingUpdates, cmdID)
|
||||
}
|
||||
s.pendingMu.Unlock()
|
||||
return containerID, found
|
||||
}
|
||||
|
||||
// SendAndWaitCtx is like SendAndWait but uses the provided context for timeout control.
|
||||
func (r *Registry) SendAndWaitCtx(ctx context.Context, agentID string, msg *agentv1.ServerMessage, cmdID string) (*agentv1.FileResult, error) {
|
||||
ch := r.RegisterPending(agentID, cmdID)
|
||||
if ch == nil {
|
||||
return nil, fmt.Errorf("agent not connected")
|
||||
}
|
||||
|
||||
if !r.Send(agentID, msg) {
|
||||
r.CancelPending(agentID, cmdID)
|
||||
return nil, fmt.Errorf("agent not connected")
|
||||
}
|
||||
|
||||
select {
|
||||
case result := <-ch:
|
||||
return result, nil
|
||||
case <-ctx.Done():
|
||||
r.CancelPending(agentID, cmdID)
|
||||
return nil, fmt.Errorf("timeout waiting for agent response")
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -153,3 +154,112 @@ func TestSend_FullChannel(t *testing.T) {
|
||||
t.Error("Send should return false when channel is full")
|
||||
}
|
||||
}
|
||||
|
||||
// ── Pending file correlations ──────────────────────────────────────────────────
|
||||
|
||||
func TestRegisterPending_UnknownAgent(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
ch := r.RegisterPending("ghost", "cmd1")
|
||||
if ch != nil {
|
||||
t.Error("expected nil channel for unknown agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvePending_Success(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("id1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
ch := r.RegisterPending("id1", "cmd1")
|
||||
if ch == nil {
|
||||
t.Fatal("expected non-nil channel")
|
||||
}
|
||||
|
||||
result := &agentv1.FileResult{CommandId: "cmd1", Success: true, Content: []byte("data")}
|
||||
r.ResolvePending("id1", "cmd1", result)
|
||||
|
||||
select {
|
||||
case got := <-ch:
|
||||
if got.CommandId != "cmd1" || !got.Success {
|
||||
t.Errorf("unexpected result: %+v", got)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for resolve")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvePending_UnknownAgent(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
// must not panic
|
||||
r.ResolvePending("ghost", "cmd1", &agentv1.FileResult{})
|
||||
}
|
||||
|
||||
func TestResolvePending_UnknownCmd(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("id1", "h", "a", "ip", "arch", "os")
|
||||
// must not panic
|
||||
r.ResolvePending("id1", "nonexistent", &agentv1.FileResult{})
|
||||
}
|
||||
|
||||
func TestCancelPending(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("id1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
r.RegisterPending("id1", "cmd1")
|
||||
r.CancelPending("id1", "cmd1")
|
||||
|
||||
// After cancel, resolving should be a no-op (not panic)
|
||||
r.ResolvePending("id1", "cmd1", &agentv1.FileResult{})
|
||||
}
|
||||
|
||||
func TestCancelPending_UnknownAgent(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
// must not panic
|
||||
r.CancelPending("ghost", "cmd1")
|
||||
}
|
||||
|
||||
func TestSendAndWaitCtx_AgentNotConnected(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
ctx := context.Background()
|
||||
_, err := r.SendAndWaitCtx(ctx, "ghost", &agentv1.ServerMessage{}, "cmd1")
|
||||
if err == nil || err.Error() != "agent not connected" {
|
||||
t.Errorf("expected 'agent not connected', got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendAndWaitCtx_Timeout(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("id1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
// Use an already-cancelled context to force immediate timeout.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // cancel immediately
|
||||
|
||||
_, err := r.SendAndWaitCtx(ctx, "id1", &agentv1.ServerMessage{}, "cmd-timeout")
|
||||
if err == nil {
|
||||
t.Error("expected timeout or not-connected error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendAndWaitCtx_Success(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("id1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
cmdID := "cmd-success"
|
||||
expected := &agentv1.FileResult{CommandId: cmdID, Success: true, Content: []byte("hello")}
|
||||
|
||||
// Simulate the agent responding after the send.
|
||||
go func() {
|
||||
// Wait briefly for RegisterPending + Send to happen.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
r.ResolvePending("id1", cmdID, expected)
|
||||
}()
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := r.SendAndWaitCtx(ctx, "id1", &agentv1.ServerMessage{}, cmdID)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.CommandId != cmdID || !result.Success {
|
||||
t.Errorf("unexpected result: %+v", result)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user