Files
Containarr/server/internal/grpc/gateway.go

138 lines
3.1 KiB
Go

package grpc
import (
"io"
"log/slog"
"net"
"github.com/containarr/server/internal/broker"
agentv1 "github.com/containarr/server/internal/proto/agentv1"
"github.com/containarr/server/internal/store"
"github.com/google/uuid"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
)
// Gateway implements the agentv1.AgentGateway gRPC service. It authenticates
// agent tunnels against the store, tracks connected agents in the Registry,
// and publishes agent lifecycle and data events on the broker.
type Gateway struct {
	// Embedded so Gateway satisfies the generated server interface; any RPC
	// not implemented on Gateway falls back to the generated stub.
	agentv1.UnimplementedAgentGatewayServer

	// Dependencies, all injected by NewGateway.
	store    *store.Store   // agent lookup by token; online/offline persistence
	registry *Registry      // Register/Deregister/UpdateContainers for live agents
	broker   *broker.Broker // destination for published agent events
}
// NewGateway returns a Gateway that uses s for agent persistence, r for
// tracking connected agents, and b for publishing events. The arguments are
// stored as given; no validation is performed.
func NewGateway(s *store.Store, r *Registry, b *broker.Broker) *Gateway {
	gw := &Gateway{
		store:    s,
		registry: r,
		broker:   b,
	}
	return gw
}
// Tunnel implements the bidirectional AgentGateway.Tunnel RPC for a single
// agent connection.
//
// Sequence:
//  1. Empty response headers are sent immediately.
//  2. The first client message must be an AgentHandshake carrying a non-empty
//     token known to the store; otherwise the RPC fails with InvalidArgument
//     (no handshake) or Unauthenticated (empty or unknown token).
//  3. The agent is registered in the Registry, persisted as online, and an
//     "agent.connected" event is published.
//  4. A sender goroutine forwards commands from the registration's cmdCh to
//     the agent, while this goroutine receives agent messages and republishes
//     container snapshots, command results and log chunks on the broker.
//
// When Tunnel returns (agent EOF, a receive error, or a send error) the
// sender goroutine is signalled to stop, the agent is deregistered, marked
// offline, and "agent.disconnected" is published.
func (g *Gateway) Tunnel(stream agentv1.AgentGateway_TunnelServer) error {
	if err := stream.SendHeader(metadata.MD{}); err != nil {
		return status.Errorf(codes.Internal, "send header: %v", err)
	}

	ctx := stream.Context()

	// Best-effort client IP. It stays empty when the transport address is
	// absent or is not in host:port form (e.g. SplitHostPort fails).
	ipAddress := ""
	if p, ok := peer.FromContext(ctx); ok && p.Addr != nil {
		if host, _, err := net.SplitHostPort(p.Addr.String()); err == nil {
			ipAddress = host
		}
	}

	first, err := stream.Recv()
	if err != nil {
		return err
	}
	hs := first.GetHandshake()
	if hs == nil {
		return status.Error(codes.InvalidArgument, "first message must be AgentHandshake")
	}
	// The token is untrusted client input; never let an empty one reach the
	// lookup, whatever rows the store may contain.
	if hs.Token == "" {
		return status.Error(codes.Unauthenticated, "unknown agent token")
	}
	existing, err := g.store.AgentByToken(hs.Token)
	if err != nil {
		// The client only sees Unauthenticated; keep the real cause (which may
		// be a storage failure rather than a bad token) in the server log.
		slog.Warn("agent authentication failed", "ip", ipAddress, "err", err)
		return status.Error(codes.Unauthenticated, "unknown agent token")
	}

	agentID := existing.ID
	slog.Info("agent connected", "id", agentID, "hostname", hs.Hostname, "ip", ipAddress)

	state := g.registry.Register(agentID, hs.Hostname, existing.Alias, ipAddress, hs.Arch, hs.Os)
	// Persistence is best-effort: a store failure does not reject the
	// connection, but it is logged instead of silently discarded.
	if err := g.store.UpsertAgent(&store.Agent{
		ID:        agentID,
		Token:     hs.Token,
		Hostname:  hs.Hostname,
		Alias:     existing.Alias,
		IPAddress: ipAddress,
		Arch:      hs.Arch,
		OS:        hs.Os,
		Online:    true,
	}); err != nil {
		slog.Warn("persist agent online state", "id", agentID, "err", err)
	}
	g.broker.Publish(broker.Event{
		Type:    "agent.connected",
		AgentID: agentID,
		Payload: map[string]string{"hostname": hs.Hostname},
	})
	defer func() {
		g.registry.Deregister(agentID)
		if err := g.store.SetAgentOffline(agentID); err != nil {
			slog.Warn("persist agent offline state", "id", agentID, "err", err)
		}
		g.broker.Publish(broker.Event{Type: "agent.disconnected", AgentID: agentID, Payload: nil})
		slog.Info("agent disconnected", "id", agentID)
	}()

	// done is closed when Tunnel returns (this defer runs before the
	// deregistration defer above). It lets the sender goroutine exit even if
	// cmdCh is never closed. Otherwise the goroutine would outlive the stream
	// and could dequeue later commands only to fail sending them.
	done := make(chan struct{})
	defer close(done)

	// Buffered so the sender can report one failure and exit without waiting
	// for the receive loop to read it.
	errCh := make(chan error, 1)
	go func() {
		for {
			select {
			case <-done:
				return
			case <-ctx.Done():
				return
			case msg, ok := <-state.cmdCh:
				if !ok {
					return
				}
				// Send is called only from this goroutine; grpc-go does not
				// permit concurrent sends on one stream.
				if err := stream.Send(msg); err != nil {
					errCh <- err
					return
				}
			}
		}
	}()

	for {
		// Surface a send failure reported since the previous message. Per
		// grpc-go, a failed send aborts the stream, so a Recv blocked below
		// also returns an error.
		select {
		case err := <-errCh:
			return err
		default:
		}

		msg, err := stream.Recv()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}

		// Payload kinds not listed here (including a repeated handshake) are
		// ignored.
		switch p := msg.Payload.(type) {
		case *agentv1.AgentMessage_Snapshot:
			g.registry.UpdateContainers(agentID, p.Snapshot.Containers)
			g.broker.Publish(broker.Event{
				Type:    "containers.updated",
				AgentID: agentID,
				Payload: p.Snapshot.Containers,
			})
		case *agentv1.AgentMessage_Result:
			g.broker.Publish(broker.Event{
				Type:    "command.result",
				AgentID: agentID,
				Payload: p.Result,
			})
		case *agentv1.AgentMessage_LogChunk:
			g.broker.Publish(broker.Event{
				Type:    "log.chunk",
				AgentID: agentID,
				Payload: p.LogChunk,
			})
		}
	}
}
// newCommandID returns a new random UUID in its canonical string form.
// Like uuid.NewString, it panics if the random source fails.
func newCommandID() string {
	id := uuid.New()
	return id.String()
}