266 lines
6.3 KiB
Go
266 lines
6.3 KiB
Go
package grpc
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
"time"
|
|
|
|
agentv1 "github.com/containarr/server/internal/proto/agentv1"
|
|
)
|
|
|
|
func TestRegisterAndGet(t *testing.T) {
|
|
r := NewRegistry()
|
|
|
|
state := r.Register("id1", "hostname1", "alias1", "10.0.0.1", "amd64", "linux")
|
|
if state == nil {
|
|
t.Fatal("Register returned nil")
|
|
}
|
|
|
|
got, ok := r.Get("id1")
|
|
if !ok {
|
|
t.Fatal("Get returned false for registered agent")
|
|
}
|
|
if got.ID != "id1" || got.Hostname != "hostname1" || got.Alias != "alias1" {
|
|
t.Errorf("unexpected state: %+v", got)
|
|
}
|
|
}
|
|
|
|
func TestGet_NotFound(t *testing.T) {
|
|
r := NewRegistry()
|
|
_, ok := r.Get("nonexistent")
|
|
if ok {
|
|
t.Error("expected false for unknown agent")
|
|
}
|
|
}
|
|
|
|
func TestDeregister(t *testing.T) {
|
|
r := NewRegistry()
|
|
r.Register("id1", "h", "a", "ip", "arch", "os")
|
|
|
|
r.Deregister("id1")
|
|
|
|
_, ok := r.Get("id1")
|
|
if ok {
|
|
t.Error("agent should not exist after Deregister")
|
|
}
|
|
}
|
|
|
|
func TestDeregister_NotExist(t *testing.T) {
|
|
r := NewRegistry()
|
|
// must not panic
|
|
r.Deregister("ghost")
|
|
}
|
|
|
|
func TestList(t *testing.T) {
|
|
r := NewRegistry()
|
|
|
|
if len(r.List()) != 0 {
|
|
t.Error("expected empty list")
|
|
}
|
|
|
|
r.Register("a1", "h1", "", "", "", "")
|
|
r.Register("a2", "h2", "", "", "", "")
|
|
|
|
if len(r.List()) != 2 {
|
|
t.Errorf("expected 2 agents, got %d", len(r.List()))
|
|
}
|
|
}
|
|
|
|
func TestUpdateContainers(t *testing.T) {
|
|
r := NewRegistry()
|
|
r.Register("id1", "h", "a", "ip", "arch", "os")
|
|
|
|
before := time.Now()
|
|
containers := []*agentv1.ContainerInfo{
|
|
{Id: "c1", Name: "web"},
|
|
{Id: "c2", Name: "db"},
|
|
}
|
|
r.UpdateContainers("id1", containers)
|
|
|
|
got, _ := r.Get("id1")
|
|
if len(got.Containers) != 2 {
|
|
t.Errorf("expected 2 containers, got %d", len(got.Containers))
|
|
}
|
|
if got.LastSeenAt.Before(before) {
|
|
t.Error("LastSeenAt should have been updated")
|
|
}
|
|
}
|
|
|
|
func TestUpdateContainers_UnknownAgent(t *testing.T) {
|
|
r := NewRegistry()
|
|
// must not panic
|
|
r.UpdateContainers("ghost", nil)
|
|
}
|
|
|
|
func TestUpdateAlias(t *testing.T) {
|
|
r := NewRegistry()
|
|
r.Register("id1", "h", "old-alias", "ip", "arch", "os")
|
|
|
|
r.UpdateAlias("id1", "new-alias")
|
|
|
|
got, _ := r.Get("id1")
|
|
if got.Alias != "new-alias" {
|
|
t.Errorf("expected 'new-alias', got %q", got.Alias)
|
|
}
|
|
}
|
|
|
|
func TestUpdateAlias_UnknownAgent(t *testing.T) {
|
|
r := NewRegistry()
|
|
// must not panic
|
|
r.UpdateAlias("ghost", "alias")
|
|
}
|
|
|
|
func TestSend(t *testing.T) {
|
|
r := NewRegistry()
|
|
state := r.Register("id1", "h", "a", "ip", "arch", "os")
|
|
|
|
msg := &agentv1.ServerMessage{}
|
|
ok := r.Send("id1", msg)
|
|
if !ok {
|
|
t.Fatal("Send returned false for connected agent")
|
|
}
|
|
|
|
// Drain the channel to verify the message arrived.
|
|
select {
|
|
case got := <-state.cmdCh:
|
|
if got != msg {
|
|
t.Error("received wrong message")
|
|
}
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timed out reading from cmdCh")
|
|
}
|
|
}
|
|
|
|
func TestSend_UnknownAgent(t *testing.T) {
|
|
r := NewRegistry()
|
|
ok := r.Send("ghost", &agentv1.ServerMessage{})
|
|
if ok {
|
|
t.Error("Send should return false for unknown agent")
|
|
}
|
|
}
|
|
|
|
func TestSend_FullChannel(t *testing.T) {
|
|
r := NewRegistry()
|
|
r.Register("id1", "h", "a", "ip", "arch", "os")
|
|
|
|
// Fill the buffer (size 16)
|
|
for i := 0; i < 16; i++ {
|
|
r.Send("id1", &agentv1.ServerMessage{})
|
|
}
|
|
|
|
// Next send on a full channel should return false
|
|
ok := r.Send("id1", &agentv1.ServerMessage{})
|
|
if ok {
|
|
t.Error("Send should return false when channel is full")
|
|
}
|
|
}
|
|
|
|
// ── Pending file correlations ──────────────────────────────────────────────────
|
|
|
|
func TestRegisterPending_UnknownAgent(t *testing.T) {
|
|
r := NewRegistry()
|
|
ch := r.RegisterPending("ghost", "cmd1")
|
|
if ch != nil {
|
|
t.Error("expected nil channel for unknown agent")
|
|
}
|
|
}
|
|
|
|
func TestResolvePending_Success(t *testing.T) {
|
|
r := NewRegistry()
|
|
r.Register("id1", "h", "a", "ip", "arch", "os")
|
|
|
|
ch := r.RegisterPending("id1", "cmd1")
|
|
if ch == nil {
|
|
t.Fatal("expected non-nil channel")
|
|
}
|
|
|
|
result := &agentv1.FileResult{CommandId: "cmd1", Success: true, Content: []byte("data")}
|
|
r.ResolvePending("id1", "cmd1", result)
|
|
|
|
select {
|
|
case got := <-ch:
|
|
if got.CommandId != "cmd1" || !got.Success {
|
|
t.Errorf("unexpected result: %+v", got)
|
|
}
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timed out waiting for resolve")
|
|
}
|
|
}
|
|
|
|
func TestResolvePending_UnknownAgent(t *testing.T) {
|
|
r := NewRegistry()
|
|
// must not panic
|
|
r.ResolvePending("ghost", "cmd1", &agentv1.FileResult{})
|
|
}
|
|
|
|
func TestResolvePending_UnknownCmd(t *testing.T) {
|
|
r := NewRegistry()
|
|
r.Register("id1", "h", "a", "ip", "arch", "os")
|
|
// must not panic
|
|
r.ResolvePending("id1", "nonexistent", &agentv1.FileResult{})
|
|
}
|
|
|
|
func TestCancelPending(t *testing.T) {
|
|
r := NewRegistry()
|
|
r.Register("id1", "h", "a", "ip", "arch", "os")
|
|
|
|
r.RegisterPending("id1", "cmd1")
|
|
r.CancelPending("id1", "cmd1")
|
|
|
|
// After cancel, resolving should be a no-op (not panic)
|
|
r.ResolvePending("id1", "cmd1", &agentv1.FileResult{})
|
|
}
|
|
|
|
func TestCancelPending_UnknownAgent(t *testing.T) {
|
|
r := NewRegistry()
|
|
// must not panic
|
|
r.CancelPending("ghost", "cmd1")
|
|
}
|
|
|
|
func TestSendAndWaitCtx_AgentNotConnected(t *testing.T) {
|
|
r := NewRegistry()
|
|
ctx := context.Background()
|
|
_, err := r.SendAndWaitCtx(ctx, "ghost", &agentv1.ServerMessage{}, "cmd1")
|
|
if err == nil || err.Error() != "agent not connected" {
|
|
t.Errorf("expected 'agent not connected', got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestSendAndWaitCtx_Timeout(t *testing.T) {
|
|
r := NewRegistry()
|
|
r.Register("id1", "h", "a", "ip", "arch", "os")
|
|
|
|
// Use an already-cancelled context to force immediate timeout.
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel() // cancel immediately
|
|
|
|
_, err := r.SendAndWaitCtx(ctx, "id1", &agentv1.ServerMessage{}, "cmd-timeout")
|
|
if err == nil {
|
|
t.Error("expected timeout or not-connected error")
|
|
}
|
|
}
|
|
|
|
func TestSendAndWaitCtx_Success(t *testing.T) {
|
|
r := NewRegistry()
|
|
r.Register("id1", "h", "a", "ip", "arch", "os")
|
|
|
|
cmdID := "cmd-success"
|
|
expected := &agentv1.FileResult{CommandId: cmdID, Success: true, Content: []byte("hello")}
|
|
|
|
// Simulate the agent responding after the send.
|
|
go func() {
|
|
// Wait briefly for RegisterPending + Send to happen.
|
|
time.Sleep(10 * time.Millisecond)
|
|
r.ResolvePending("id1", cmdID, expected)
|
|
}()
|
|
|
|
ctx := context.Background()
|
|
result, err := r.SendAndWaitCtx(ctx, "id1", &agentv1.ServerMessage{}, cmdID)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if result.CommandId != cmdID || !result.Success {
|
|
t.Errorf("unexpected result: %+v", result)
|
|
}
|
|
}
|