/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hertzbeat.ai.agent.config;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.usthe.sureness.mgt.SurenessSecurityManager;
import com.usthe.sureness.subject.SubjectSum;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.util.Assert;
import jakarta.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.hertzbeat.ai.agent.config.McpContextHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.HttpStatusCode;
import org.springframework.web.servlet.function.RouterFunction;
import org.springframework.web.servlet.function.RouterFunctions;
import org.springframework.web.servlet.function.ServerRequest;
import org.springframework.web.servlet.function.ServerResponse;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

public class CustomSseServerTransport
implements McpServerTransportProvider {
    private static final Logger log = LoggerFactory.getLogger(CustomSseServerTransport.class);
    private final ObjectMapper objectMapper;
    private final String messageEndpoint;
    private final String sseEndpoint;
    private final String baseUrl;
    private final RouterFunction<ServerResponse> routerFunction;
    private McpServerSession.Factory sessionFactory;
    private final Map<String, Object> sessionRequest = new HashMap<String, Object>();
    private final ConcurrentHashMap<String, McpServerSession> sessions = new ConcurrentHashMap();
    private volatile boolean isClosing = false;

    public CustomSseServerTransport(ObjectMapper objectMapper, String messageEndpoint) {
        this(objectMapper, messageEndpoint, "/sse");
    }

    public CustomSseServerTransport(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
        this(objectMapper, "", messageEndpoint, sseEndpoint);
    }

    public CustomSseServerTransport(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) {
        Assert.notNull((Object)objectMapper, (String)"ObjectMapper must not be null");
        Assert.notNull((Object)baseUrl, (String)"Message base URL must not be null");
        Assert.notNull((Object)messageEndpoint, (String)"Message endpoint must not be null");
        Assert.notNull((Object)sseEndpoint, (String)"SSE endpoint must not be null");
        this.objectMapper = objectMapper;
        this.baseUrl = baseUrl;
        this.messageEndpoint = messageEndpoint;
        this.sseEndpoint = sseEndpoint;
        this.routerFunction = RouterFunctions.route().GET(this.sseEndpoint, this::handleSseConnection).POST(this.messageEndpoint, this::handleMessage).build();
    }

    public Mono<Void> notifyClients(String method, Object params) {
        if (this.sessions.isEmpty()) {
            log.debug("No active sessions to broadcast message to");
            return Mono.empty();
        }
        log.debug("Attempting to broadcast message to {} active sessions", (Object)this.sessions.size());
        return Flux.fromIterable(this.sessions.values()).flatMap(session -> session.sendNotification(method, params).doOnError(e -> log.error("Failed to send message to session {}: {}", (Object)session.getId(), (Object)e.getMessage())).onErrorComplete()).then();
    }

    public Mono<Void> closeGracefully() {
        return Flux.fromIterable(this.sessions.values()).doFirst(() -> {
            this.isClosing = true;
            log.debug("Initiating graceful shutdown with {} active sessions", (Object)this.sessions.size());
        }).flatMap(McpServerSession::closeGracefully).then().doOnSuccess(v -> log.debug("Graceful shutdown completed"));
    }

    private ServerResponse handleSseConnection(ServerRequest request) {
        log.debug("Handling SSE connection for request: {}", (Object)request);
        HttpServletRequest servletRequest = request.servletRequest();
        try {
            log.debug("Processing SSE connection for servlet request: {}", (Object)servletRequest);
            log.debug("Authorization header: {}", (Object)servletRequest.getHeader("Authorization"));
        }
        catch (Exception e) {
            log.error("Authentication failed for SSE connection: {}", (Object)e.getMessage());
            return ServerResponse.status((HttpStatusCode)HttpStatus.UNAUTHORIZED).body((Object)("Unauthorized: " + e.getMessage()));
        }
        if (this.isClosing) {
            return ServerResponse.status((HttpStatusCode)HttpStatus.SERVICE_UNAVAILABLE).body((Object)"Server is shutting down");
        }
        String sessionId = UUID.randomUUID().toString();
        log.debug("Generated session ID for SSE connection: {}", (Object)sessionId);
        log.debug("Creating new SSE connection for session: {}", (Object)sessionId);
        return ServerResponse.sse(sseBuilder -> {
            sseBuilder.onComplete(() -> {
                log.debug("SSE connection completed for session: {}", (Object)sessionId);
                this.sessions.remove(sessionId);
            });
            sseBuilder.onTimeout(() -> {
                log.debug("SSE connection timed out for session: {}", (Object)sessionId);
                this.sessions.remove(sessionId);
            });
            WebMvcMcpSessionTransport sessionTransport = new WebMvcMcpSessionTransport(sessionId, (ServerResponse.SseBuilder)sseBuilder);
            McpServerSession session = this.sessionFactory.create((McpServerTransport)sessionTransport);
            this.sessionRequest.put(sessionId, request.servletRequest());
            this.sessions.put(sessionId, session);
            try {
                sseBuilder.id(sessionId).event("endpoint").data((Object)(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId));
            }
            catch (Exception e) {
                log.error("Failed to send initial endpoint event: {}", (Object)e.getMessage());
                sseBuilder.error((Throwable)e);
            }
        }, (Duration)Duration.ZERO);
    }

    private ServerResponse handleMessage(ServerRequest request) {
        if (this.isClosing) {
            return ServerResponse.status((HttpStatusCode)HttpStatus.SERVICE_UNAVAILABLE).body((Object)"Server is shutting down");
        }
        if (request.param("sessionId").isEmpty()) {
            return ServerResponse.badRequest().body((Object)new McpError((Object)"Session ID missing in message endpoint"));
        }
        String sessionId = (String)request.param("sessionId").get();
        McpServerSession session = this.sessions.get(sessionId);
        log.debug("Authorization header for message request: {}", (Object)request.servletRequest().getHeader("Authorization"));
        SubjectSum subject = SurenessSecurityManager.getInstance().checkIn(this.sessionRequest.get(sessionId));
        McpContextHolder.setSubject(subject);
        if (session == null) {
            return ServerResponse.status((HttpStatusCode)HttpStatus.NOT_FOUND).body((Object)new McpError((Object)("Session not found: " + sessionId)));
        }
        try {
            String body = (String)request.body(String.class);
            McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage((ObjectMapper)this.objectMapper, (String)body);
            session.handle(message).block();
            return ServerResponse.ok().build();
        }
        catch (IOException | IllegalArgumentException e) {
            log.error("Failed to deserialize message: {}", (Object)e.getMessage());
            return ServerResponse.badRequest().body((Object)new McpError((Object)"Invalid message format"));
        }
        catch (Exception e) {
            log.error("Error handling message: {}", (Object)e.getMessage());
            return ServerResponse.status((HttpStatusCode)HttpStatus.INTERNAL_SERVER_ERROR).body((Object)new McpError((Object)e.getMessage()));
        }
    }

    public RouterFunction<ServerResponse> getRouterFunction() {
        return this.routerFunction;
    }

    public void setSessionFactory(McpServerSession.Factory sessionFactory) {
        this.sessionFactory = sessionFactory;
    }

    private class WebMvcMcpSessionTransport
    implements McpServerTransport {
        private final String sessionId;
        private final ServerResponse.SseBuilder sseBuilder;

        WebMvcMcpSessionTransport(String sessionId, ServerResponse.SseBuilder sseBuilder) {
            this.sessionId = sessionId;
            this.sseBuilder = sseBuilder;
            log.debug("Session transport {} initialized with SSE builder", (Object)sessionId);
        }

        public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
            return Mono.fromRunnable(() -> {
                try {
                    String jsonText = CustomSseServerTransport.this.objectMapper.writeValueAsString((Object)message);
                    this.sseBuilder.id(this.sessionId).event("message").data((Object)jsonText);
                    log.debug("Message sent to session {}", (Object)this.sessionId);
                }
                catch (Exception e) {
                    log.error("Failed to send message to session {}: {}", (Object)this.sessionId, (Object)e.getMessage());
                    this.sseBuilder.error((Throwable)e);
                }
            });
        }

        public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
            return (T)CustomSseServerTransport.this.objectMapper.convertValue(data, typeRef);
        }

        public Mono<Void> closeGracefully() {
            return Mono.fromRunnable(() -> {
                log.debug("Closing session transport: {}", (Object)this.sessionId);
                try {
                    this.sseBuilder.complete();
                    log.debug("Successfully completed SSE builder for session {}", (Object)this.sessionId);
                }
                catch (Exception e) {
                    log.warn("Failed to complete SSE builder for session {}: {}", (Object)this.sessionId, (Object)e.getMessage());
                }
            });
        }

        public void close() {
            try {
                this.sseBuilder.complete();
                log.debug("Successfully completed SSE builder for session {}", (Object)this.sessionId);
            }
            catch (Exception e) {
                log.warn("Failed to complete SSE builder for session {}: {}", (Object)this.sessionId, (Object)e.getMessage());
            }
        }
    }
}

