package grpc

import (
	"context"
	"errors"
	"io"
	"log/slog"
	"net"

	"github.com/containarr/server/internal/broker"
	agentv1 "github.com/containarr/server/internal/proto/agentv1"
	"github.com/containarr/server/internal/store"
	"github.com/google/uuid"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/peer"
	"google.golang.org/grpc/status"
)

// Gateway implements the AgentGateway gRPC service: it authenticates agents,
// tracks them in the Registry and the store, forwards queued commands to them
// and republishes their messages on the event broker.
type Gateway struct {
	agentv1.UnimplementedAgentGatewayServer
	store    *store.Store
	registry *Registry
	broker   *broker.Broker
}

// NewGateway returns a Gateway backed by the given store, agent registry and
// event broker.
func NewGateway(s *store.Store, r *Registry, b *broker.Broker) *Gateway {
	return &Gateway{store: s, registry: r, broker: b}
}

// Tunnel serves one long-lived bidirectional stream for a single agent.
//
// The first message from the agent must be an AgentHandshake whose token is
// known to the store; otherwise the stream fails with InvalidArgument or
// Unauthenticated. After a successful handshake the agent is registered,
// persisted as online and an "agent.connected" event is published. From then
// on, commands read from the agent's registry channel are sent to the agent,
// and snapshots, command results and log chunks received from it are
// republished on the broker.
//
// Tunnel returns nil when the agent closes its side of the stream (io.EOF) and
// the send/receive error otherwise. On return the agent is deregistered,
// marked offline and an "agent.disconnected" event is published.
func (g *Gateway) Tunnel(stream agentv1.AgentGateway_TunnelServer) error {
	// Flush response headers right away so the agent sees the stream as
	// established before the server has anything else to send.
	if err := stream.SendHeader(metadata.MD{}); err != nil {
		return status.Errorf(codes.Internal, "send header: %v", err)
	}

	first, err := stream.Recv()
	if err != nil {
		return err
	}
	hs := first.GetHandshake()
	if hs == nil {
		return status.Error(codes.InvalidArgument, "first message must be AgentHandshake")
	}
	// An empty token is never a valid credential; reject it without consulting
	// the store so a row with an empty token can never be matched.
	if hs.Token == "" {
		return status.Error(codes.Unauthenticated, "unknown agent token")
	}
	existing, err := g.store.AgentByToken(hs.Token)
	if err != nil {
		// The client only learns that the token was rejected; the actual cause
		// (possibly a store failure rather than an unknown token) is logged.
		slog.Warn("agent handshake rejected", "hostname", hs.Hostname, "err", err)
		return status.Error(codes.Unauthenticated, "unknown agent token")
	}

	ipAddress := peerHost(stream.Context())
	agentID := existing.ID
	slog.Info("agent connected", "id", agentID, "hostname", hs.Hostname, "ip", ipAddress)

	state := g.registry.Register(agentID, hs.Hostname, existing.Alias, ipAddress, hs.Arch, hs.Os)
	agent := &store.Agent{
		ID:        agentID,
		Token:     hs.Token,
		Hostname:  hs.Hostname,
		Alias:     existing.Alias,
		IPAddress: ipAddress,
		Arch:      hs.Arch,
		OS:        hs.Os,
		Online:    true,
	}
	if err := g.store.UpsertAgent(agent); err != nil {
		// Persisting is best-effort: the agent is already live in the registry,
		// so keep serving it and just record the failure.
		slog.Warn("persist connected agent", "id", agentID, "err", err)
	}
	g.broker.Publish(broker.Event{
		Type:    "agent.connected",
		AgentID: agentID,
		Payload: map[string]string{"hostname": hs.Hostname},
	})

	// NOTE(review): if the same agent reconnects before this stream notices it
	// is gone, this cleanup may deregister / mark offline the newer connection;
	// confirm that Registry.Deregister and SetAgentOffline guard against that.
	defer func() {
		g.registry.Deregister(agentID)
		if err := g.store.SetAgentOffline(agentID); err != nil {
			slog.Warn("mark agent offline", "id", agentID, "err", err)
		}
		g.broker.Publish(broker.Event{Type: "agent.disconnected", AgentID: agentID, Payload: nil})
		slog.Info("agent disconnected", "id", agentID)
	}()

	// receive republishes agent messages until the stream ends. It returns nil
	// when the agent closes its send side (io.EOF) and the Recv error otherwise.
	receive := func() error {
		for {
			msg, err := stream.Recv()
			if errors.Is(err, io.EOF) {
				return nil
			}
			if err != nil {
				return err
			}
			switch p := msg.Payload.(type) {
			case *agentv1.AgentMessage_Snapshot:
				// The generated getter is nil-safe, unlike p.Snapshot.Containers.
				containers := p.Snapshot.GetContainers()
				g.registry.UpdateContainers(agentID, containers)
				g.broker.Publish(broker.Event{
					Type:    "containers.updated",
					AgentID: agentID,
					Payload: containers,
				})
			case *agentv1.AgentMessage_Result:
				g.broker.Publish(broker.Event{
					Type:    "command.result",
					AgentID: agentID,
					Payload: p.Result,
				})
			case *agentv1.AgentMessage_LogChunk:
				g.broker.Publish(broker.Event{
					Type:    "log.chunk",
					AgentID: agentID,
					Payload: p.LogChunk,
				})
			}
			// Any other payload (e.g. a repeated handshake) is ignored.
		}
	}

	// Recv runs on its own goroutine while this handler goroutine performs
	// every Send. gRPC allows one goroutine sending and another receiving on
	// the same stream concurrently. This shape means a failed Send ends the
	// tunnel at once instead of going unnoticed until the next Recv, and no
	// Send can happen after Tunnel returns. The receiver exits once Recv fails,
	// which happens when the stream's context is cancelled after the handler
	// returns; the buffered channel lets it exit without a reader. It may still
	// finish handling one in-flight message after the deferred cleanup ran.
	recvErr := make(chan error, 1)
	go func() { recvErr <- receive() }()

	cmds := state.cmdCh
	for {
		select {
		case err := <-recvErr:
			return err
		case cmd, ok := <-cmds:
			if !ok {
				// Command channel closed: stop sending (a nil channel is never
				// ready in a select) but keep serving the agent's messages.
				cmds = nil
				continue
			}
			if err := stream.Send(cmd); err != nil {
				return err
			}
		}
	}
}

// peerHost returns the host part of the remote address gRPC recorded in ctx,
// or "" when there is no peer information or the address is not in host:port
// form (for example a Unix socket path).
func peerHost(ctx context.Context) string {
	p, ok := peer.FromContext(ctx)
	if !ok || p.Addr == nil {
		return ""
	}
	host, _, err := net.SplitHostPort(p.Addr.String())
	if err != nil {
		return ""
	}
	return host
}

// newCommandID returns a fresh random identifier for a command.
func newCommandID() string {
	return uuid.NewString()
}