/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.deploy;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.nodes.TransportNodesAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.breaker.MLCircuitBreakerService;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.transport.deploy.MLDeployModelInput;
import org.opensearch.ml.common.transport.deploy.MLDeployModelNodeRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelNodeResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelNodesRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelNodesResponse;
import org.opensearch.ml.common.transport.forward.MLForwardInput;
import org.opensearch.ml.common.transport.forward.MLForwardRequest;
import org.opensearch.ml.common.transport.forward.MLForwardRequestType;
import org.opensearch.ml.common.transport.forward.MLForwardResponse;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.MLExceptionUtils;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportRequest;
import org.opensearch.transport.TransportResponseHandler;
import org.opensearch.transport.TransportService;

/**
 * Nodes-level transport action ({@code cluster:admin/opensearch/ml/deploy_model_on_nodes}) that asks each
 * targeted node to deploy a model.
 *
 * <p>Each node starts deployment through {@link MLModelManager#deployModel} and immediately answers with a
 * {@code "received"} status for the model; it does not wait for the deployment listener. The final outcome
 * (success or failure) is reported separately by sending a {@link MLForwardRequestType#DEPLOY_MODEL_DONE}
 * forward request to the coordinating node.
 */
public class TransportDeployModelOnNodeAction
extends TransportNodesAction<MLDeployModelNodesRequest, MLDeployModelNodesResponse, MLDeployModelNodeRequest, MLDeployModelNodeResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportDeployModelOnNodeAction.class);
    /** Internal action used to report deployment completion back to the coordinating node. */
    private static final String FORWARD_ACTION_NAME = "cluster:admin/opensearch/mlinternal/forward";

    TransportService transportService;
    ModelHelper modelHelper;
    MLTaskManager mlTaskManager;
    MLModelManager mlModelManager;
    ClusterService clusterService;
    ThreadPool threadPool;
    Client client;
    NamedXContentRegistry xContentRegistry;
    MLCircuitBreakerService mlCircuitBreakerService;
    MLStats mlStats;

    @Inject
    public TransportDeployModelOnNodeAction(
        TransportService transportService,
        ActionFilters actionFilters,
        ModelHelper modelHelper,
        MLTaskManager mlTaskManager,
        MLModelManager mlModelManager,
        ClusterService clusterService,
        ThreadPool threadPool,
        Client client,
        NamedXContentRegistry xContentRegistry,
        MLCircuitBreakerService mlCircuitBreakerService,
        MLStats mlStats
    ) {
        super(
            "cluster:admin/opensearch/ml/deploy_model_on_nodes",
            threadPool,
            clusterService,
            transportService,
            actionFilters,
            MLDeployModelNodesRequest::new,
            MLDeployModelNodeRequest::new,
            "management",
            MLDeployModelNodeResponse.class
        );
        this.transportService = transportService;
        this.modelHelper = modelHelper;
        this.mlTaskManager = mlTaskManager;
        this.mlModelManager = mlModelManager;
        this.clusterService = clusterService;
        this.threadPool = threadPool;
        this.client = client;
        this.xContentRegistry = xContentRegistry;
        this.mlCircuitBreakerService = mlCircuitBreakerService;
        this.mlStats = mlStats;
    }

    /** Aggregates per-node responses and failures into a response tagged with this cluster's name. */
    @Override
    protected MLDeployModelNodesResponse newResponse(
        MLDeployModelNodesRequest nodesRequest,
        List<MLDeployModelNodeResponse> responses,
        List<FailedNodeException> failures
    ) {
        return new MLDeployModelNodesResponse(this.clusterService.getClusterName(), responses, failures);
    }

    /** Wraps the nodes-level request into the request sent to each individual node. */
    @Override
    protected MLDeployModelNodeRequest newNodeRequest(MLDeployModelNodesRequest request) {
        return new MLDeployModelNodeRequest(request);
    }

    /** Deserializes a single node's response from the transport stream. */
    @Override
    protected MLDeployModelNodeResponse newNodeResponse(StreamInput in) throws IOException {
        return new MLDeployModelNodeResponse(in);
    }

    /** Runs on each target node: starts deployment and returns an acknowledgement response. */
    @Override
    protected MLDeployModelNodeResponse nodeOperation(MLDeployModelNodeRequest request) {
        return this.createDeployModelNodeResponse(request.getMLDeployModelNodesRequest());
    }

    /**
     * Starts deploying the requested model on the local node and builds the acknowledgement response.
     *
     * @param nodesRequest the nodes-level request carrying the {@link MLDeployModelInput}
     * @return a response for the local node whose status map is {@code modelId -> "received"}
     */
    private MLDeployModelNodeResponse createDeployModelNodeResponse(MLDeployModelNodesRequest nodesRequest) {
        MLDeployModelInput deployModelInput = nodesRequest.getMlDeployModelInput();
        String modelId = deployModelInput.getModelId();
        String taskId = deployModelInput.getTaskId();
        String coordinatingNodeId = deployModelInput.getCoordinatingNodeId();
        MLTask mlTask = deployModelInput.getMlTask();
        String modelContentHash = deployModelInput.getModelContentHash();
        boolean deployToAllNodes = deployModelInput.getIsDeployToAllNodes();

        HashMap<String, String> modelDeployStatus = new HashMap<>();
        modelDeployStatus.put(modelId, "received");
        String localNodeId = this.clusterService.localNode().getId();

        // Only logs the outcome of delivering the DEPLOY_MODEL_DONE message to the coordinator.
        ActionListener<MLForwardResponse> taskDoneListener = ActionListener.wrap(
            res -> log.info("deploy model task done {}", taskId),
            ex -> MLExceptionUtils.logException("Deploy model task failed: " + taskId, ex, log)
        );

        ActionListener<String> deployListener = ActionListener.wrap(r -> {
            MLForwardInput mlForwardInput = MLForwardInput
                .builder()
                .requestType(MLForwardRequestType.DEPLOY_MODEL_DONE)
                .taskId(taskId)
                .modelId(modelId)
                .workerNodeId(localNodeId)
                .build();
            this.sendDeployModelDoneMessage(coordinatingNodeId, mlForwardInput, taskDoneListener);
        }, e -> {
            MLForwardInput mlForwardInput = MLForwardInput
                .builder()
                .requestType(MLForwardRequestType.DEPLOY_MODEL_DONE)
                .taskId(taskId)
                .modelId(modelId)
                .workerNodeId(localNodeId)
                .error(MLExceptionUtils.getRootCauseMessage(e))
                .build();
            this.sendDeployModelDoneMessage(coordinatingNodeId, mlForwardInput, taskDoneListener);
        });

        this.deployModel(
            modelId,
            modelContentHash,
            mlTask.getFunctionName(),
            localNodeId,
            coordinatingNodeId,
            deployToAllNodes,
            mlTask,
            deployListener
        );
        return new MLDeployModelNodeResponse(this.clusterService.localNode(), modelDeployStatus);
    }

    /**
     * Sends a DEPLOY_MODEL_DONE forward request to the coordinating node.
     *
     * <p>If the coordinating node is not present in the current cluster state (e.g. it left the cluster),
     * {@code taskDoneListener} is failed instead of passing a {@code null} node to the transport layer.
     *
     * @param coordinatingNodeId id of the node to notify
     * @param mlForwardInput     payload describing the deployment outcome
     * @param taskDoneListener   notified with the forward response or the delivery failure
     */
    private void sendDeployModelDoneMessage(
        String coordinatingNodeId,
        MLForwardInput mlForwardInput,
        ActionListener<MLForwardResponse> taskDoneListener
    ) {
        DiscoveryNode coordinatingNode = this.getNodeById(coordinatingNodeId);
        if (coordinatingNode == null) {
            taskDoneListener
                .onFailure(new IllegalStateException("Coordinating node " + coordinatingNodeId + " not found in cluster state"));
            return;
        }
        MLForwardRequest deployModelDoneMessage = new MLForwardRequest(mlForwardInput);
        // The request is sent with the current thread context stashed; the original context is restored on close.
        try (ThreadContext.StoredContext ignored = this.client.threadPool().getThreadContext().stashContext()) {
            this.transportService
                .sendRequest(
                    coordinatingNode,
                    FORWARD_ACTION_NAME,
                    deployModelDoneMessage,
                    new ActionListenerResponseHandler<MLForwardResponse>(taskDoneListener, MLForwardResponse::new)
                );
        }
    }

    /**
     * Looks up a node in the current cluster state.
     *
     * @param nodeId node id to find; may be {@code null}
     * @return the matching node, or {@code null} if no node with that id is in the cluster state
     */
    private DiscoveryNode getNodeById(String nodeId) {
        if (nodeId == null) {
            return null;
        }
        DiscoveryNodes nodes = this.clusterService.state().getNodes();
        return nodes.get(nodeId);
    }

    /**
     * Delegates model deployment to {@link MLModelManager#deployModel}.
     *
     * <p>Before {@code listener} is notified, and only when this node is not the coordinating node, the task is
     * removed from the local {@link MLTaskManager}. This cleanup also runs when {@code deployModel} throws
     * synchronously, so the failure path does not leave the task behind.
     *
     * @param listener notified with the deployment result or failure
     */
    private void deployModel(
        String modelId,
        String modelContentHash,
        FunctionName functionName,
        String localNodeId,
        String coordinatingNodeId,
        boolean deployToAllNodes,
        MLTask mlTask,
        ActionListener<String> listener
    ) {
        ActionListener<String> cleanupListener = ActionListener.runBefore(listener, () -> {
            if (!coordinatingNodeId.equals(localNodeId)) {
                this.mlTaskManager.remove(mlTask.getTaskId());
            }
        });
        try {
            log.debug("start deploying model {}", modelId);
            this.mlModelManager.deployModel(modelId, modelContentHash, functionName, deployToAllNodes, false, mlTask, cleanupListener);
        } catch (Exception e) {
            MLExceptionUtils.logException("Failed to deploy model " + modelId, e, log);
            // Route through cleanupListener (not the raw listener) so the worker-side task cleanup still happens.
            cleanupListener.onFailure(e);
        }
    }
}

