feat: add first page with auth and containers list and agents
This commit is contained in:
137
server/internal/grpc/gateway.go
Normal file
137
server/internal/grpc/gateway.go
Normal file
@ -0,0 +1,137 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
|
||||
"github.com/containarr/server/internal/broker"
|
||||
agentv1 "github.com/containarr/server/internal/proto/agentv1"
|
||||
"github.com/containarr/server/internal/store"
|
||||
"github.com/google/uuid"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type Gateway struct {
|
||||
agentv1.UnimplementedAgentGatewayServer
|
||||
store *store.Store
|
||||
registry *Registry
|
||||
broker *broker.Broker
|
||||
}
|
||||
|
||||
func NewGateway(s *store.Store, r *Registry, b *broker.Broker) *Gateway {
|
||||
return &Gateway{store: s, registry: r, broker: b}
|
||||
}
|
||||
|
||||
func (g *Gateway) Tunnel(stream agentv1.AgentGateway_TunnelServer) error {
|
||||
if err := stream.SendHeader(metadata.MD{}); err != nil {
|
||||
return status.Errorf(codes.Internal, "send header: %v", err)
|
||||
}
|
||||
|
||||
first, err := stream.Recv()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hs := first.GetHandshake()
|
||||
if hs == nil {
|
||||
return status.Error(codes.InvalidArgument, "first message must be AgentHandshake")
|
||||
}
|
||||
|
||||
existing, err := g.store.AgentByToken(hs.Token)
|
||||
if err != nil {
|
||||
return status.Error(codes.Unauthenticated, "unknown agent token")
|
||||
}
|
||||
|
||||
// Extract peer IP from the gRPC connection.
|
||||
ipAddress := ""
|
||||
if p, ok := peer.FromContext(stream.Context()); ok {
|
||||
if host, _, err := net.SplitHostPort(p.Addr.String()); err == nil {
|
||||
ipAddress = host
|
||||
}
|
||||
}
|
||||
|
||||
agentID := existing.ID
|
||||
slog.Info("agent connected", "id", agentID, "hostname", hs.Hostname, "ip", ipAddress)
|
||||
|
||||
state := g.registry.Register(agentID, hs.Hostname, existing.Alias, ipAddress, hs.Arch, hs.Os)
|
||||
_ = g.store.UpsertAgent(&store.Agent{
|
||||
ID: agentID,
|
||||
Token: hs.Token,
|
||||
Hostname: hs.Hostname,
|
||||
Alias: existing.Alias,
|
||||
IPAddress: ipAddress,
|
||||
Arch: hs.Arch,
|
||||
OS: hs.Os,
|
||||
Online: true,
|
||||
})
|
||||
|
||||
g.broker.Publish(broker.Event{
|
||||
Type: "agent.connected",
|
||||
AgentID: agentID,
|
||||
Payload: map[string]string{"hostname": hs.Hostname},
|
||||
})
|
||||
|
||||
defer func() {
|
||||
g.registry.Deregister(agentID)
|
||||
_ = g.store.SetAgentOffline(agentID)
|
||||
g.broker.Publish(broker.Event{Type: "agent.disconnected", AgentID: agentID, Payload: nil})
|
||||
slog.Info("agent disconnected", "id", agentID)
|
||||
}()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
for msg := range state.cmdCh {
|
||||
if err := stream.Send(msg); err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case err := <-errCh:
|
||||
return err
|
||||
default:
|
||||
}
|
||||
|
||||
msg, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch p := msg.Payload.(type) {
|
||||
case *agentv1.AgentMessage_Snapshot:
|
||||
g.registry.UpdateContainers(agentID, p.Snapshot.Containers)
|
||||
g.broker.Publish(broker.Event{
|
||||
Type: "containers.updated",
|
||||
AgentID: agentID,
|
||||
Payload: p.Snapshot.Containers,
|
||||
})
|
||||
|
||||
case *agentv1.AgentMessage_Result:
|
||||
g.broker.Publish(broker.Event{
|
||||
Type: "command.result",
|
||||
AgentID: agentID,
|
||||
Payload: p.Result,
|
||||
})
|
||||
|
||||
case *agentv1.AgentMessage_LogChunk:
|
||||
g.broker.Publish(broker.Event{
|
||||
Type: "log.chunk",
|
||||
AgentID: agentID,
|
||||
Payload: p.LogChunk,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newCommandID() string {
|
||||
return uuid.NewString()
|
||||
}
|
||||
105
server/internal/grpc/registry.go
Normal file
105
server/internal/grpc/registry.go
Normal file
@ -0,0 +1,105 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
agentv1 "github.com/containarr/server/internal/proto/agentv1"
|
||||
)
|
||||
|
||||
type AgentState struct {
|
||||
ID string
|
||||
Hostname string
|
||||
Alias string
|
||||
IPAddress string
|
||||
Arch string
|
||||
OS string
|
||||
LastSeenAt time.Time
|
||||
Containers []*agentv1.ContainerInfo
|
||||
|
||||
cmdCh chan *agentv1.ServerMessage
|
||||
}
|
||||
|
||||
type Registry struct {
|
||||
mu sync.RWMutex
|
||||
agents map[string]*AgentState
|
||||
}
|
||||
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{agents: make(map[string]*AgentState)}
|
||||
}
|
||||
|
||||
func (r *Registry) Register(id, hostname, alias, ipAddress, arch, os string) *AgentState {
|
||||
state := &AgentState{
|
||||
ID: id,
|
||||
Hostname: hostname,
|
||||
Alias: alias,
|
||||
IPAddress: ipAddress,
|
||||
Arch: arch,
|
||||
OS: os,
|
||||
cmdCh: make(chan *agentv1.ServerMessage, 16),
|
||||
}
|
||||
r.mu.Lock()
|
||||
r.agents[id] = state
|
||||
r.mu.Unlock()
|
||||
return state
|
||||
}
|
||||
|
||||
func (r *Registry) Deregister(id string) {
|
||||
r.mu.Lock()
|
||||
if s, ok := r.agents[id]; ok {
|
||||
close(s.cmdCh)
|
||||
delete(r.agents, id)
|
||||
}
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
func (r *Registry) Get(id string) (*AgentState, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
s, ok := r.agents[id]
|
||||
return s, ok
|
||||
}
|
||||
|
||||
func (r *Registry) List() []*AgentState {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
out := make([]*AgentState, 0, len(r.agents))
|
||||
for _, s := range r.agents {
|
||||
out = append(out, s)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (r *Registry) UpdateContainers(id string, containers []*agentv1.ContainerInfo) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if s, ok := r.agents[id]; ok {
|
||||
s.Containers = containers
|
||||
s.LastSeenAt = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateAlias refreshes the alias for a live agent (called after an admin update).
|
||||
func (r *Registry) UpdateAlias(id, alias string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if s, ok := r.agents[id]; ok {
|
||||
s.Alias = alias
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Registry) Send(agentID string, msg *agentv1.ServerMessage) bool {
|
||||
r.mu.RLock()
|
||||
s, ok := r.agents[agentID]
|
||||
r.mu.RUnlock()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case s.cmdCh <- msg:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
155
server/internal/grpc/registry_test.go
Normal file
155
server/internal/grpc/registry_test.go
Normal file
@ -0,0 +1,155 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
agentv1 "github.com/containarr/server/internal/proto/agentv1"
|
||||
)
|
||||
|
||||
func TestRegisterAndGet(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
state := r.Register("id1", "hostname1", "alias1", "10.0.0.1", "amd64", "linux")
|
||||
if state == nil {
|
||||
t.Fatal("Register returned nil")
|
||||
}
|
||||
|
||||
got, ok := r.Get("id1")
|
||||
if !ok {
|
||||
t.Fatal("Get returned false for registered agent")
|
||||
}
|
||||
if got.ID != "id1" || got.Hostname != "hostname1" || got.Alias != "alias1" {
|
||||
t.Errorf("unexpected state: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGet_NotFound(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
_, ok := r.Get("nonexistent")
|
||||
if ok {
|
||||
t.Error("expected false for unknown agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeregister(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("id1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
r.Deregister("id1")
|
||||
|
||||
_, ok := r.Get("id1")
|
||||
if ok {
|
||||
t.Error("agent should not exist after Deregister")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeregister_NotExist(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
// must not panic
|
||||
r.Deregister("ghost")
|
||||
}
|
||||
|
||||
func TestList(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
if len(r.List()) != 0 {
|
||||
t.Error("expected empty list")
|
||||
}
|
||||
|
||||
r.Register("a1", "h1", "", "", "", "")
|
||||
r.Register("a2", "h2", "", "", "", "")
|
||||
|
||||
if len(r.List()) != 2 {
|
||||
t.Errorf("expected 2 agents, got %d", len(r.List()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateContainers(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("id1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
before := time.Now()
|
||||
containers := []*agentv1.ContainerInfo{
|
||||
{Id: "c1", Name: "web"},
|
||||
{Id: "c2", Name: "db"},
|
||||
}
|
||||
r.UpdateContainers("id1", containers)
|
||||
|
||||
got, _ := r.Get("id1")
|
||||
if len(got.Containers) != 2 {
|
||||
t.Errorf("expected 2 containers, got %d", len(got.Containers))
|
||||
}
|
||||
if got.LastSeenAt.Before(before) {
|
||||
t.Error("LastSeenAt should have been updated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateContainers_UnknownAgent(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
// must not panic
|
||||
r.UpdateContainers("ghost", nil)
|
||||
}
|
||||
|
||||
func TestUpdateAlias(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("id1", "h", "old-alias", "ip", "arch", "os")
|
||||
|
||||
r.UpdateAlias("id1", "new-alias")
|
||||
|
||||
got, _ := r.Get("id1")
|
||||
if got.Alias != "new-alias" {
|
||||
t.Errorf("expected 'new-alias', got %q", got.Alias)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAlias_UnknownAgent(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
// must not panic
|
||||
r.UpdateAlias("ghost", "alias")
|
||||
}
|
||||
|
||||
func TestSend(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
state := r.Register("id1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
msg := &agentv1.ServerMessage{}
|
||||
ok := r.Send("id1", msg)
|
||||
if !ok {
|
||||
t.Fatal("Send returned false for connected agent")
|
||||
}
|
||||
|
||||
// Drain the channel to verify the message arrived.
|
||||
select {
|
||||
case got := <-state.cmdCh:
|
||||
if got != msg {
|
||||
t.Error("received wrong message")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out reading from cmdCh")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_UnknownAgent(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
ok := r.Send("ghost", &agentv1.ServerMessage{})
|
||||
if ok {
|
||||
t.Error("Send should return false for unknown agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_FullChannel(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("id1", "h", "a", "ip", "arch", "os")
|
||||
|
||||
// Fill the buffer (size 16)
|
||||
for i := 0; i < 16; i++ {
|
||||
r.Send("id1", &agentv1.ServerMessage{})
|
||||
}
|
||||
|
||||
// Next send on a full channel should return false
|
||||
ok := r.Send("id1", &agentv1.ServerMessage{})
|
||||
if ok {
|
||||
t.Error("Send should return false when channel is full")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user