/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.flowframework.workflow;

import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.Set;
import java.util.stream.Stream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.MediaType;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.common.WorkflowResources;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.exception.WorkflowStepException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.flowframework.workflow.AbstractRetryableWorkflowStep;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.threadpool.ThreadPool;

/**
 * Base workflow step that registers a local model through the ML Commons client and then waits
 * (via {@link AbstractRetryableWorkflowStep#retryableGetMlTask}) for the registration task to complete.
 *
 * <p>Subclasses only decide which input keys are required and which are optional; the parsing of
 * those inputs into an {@link MLRegisterModelInput} and the registration itself are shared here.
 */
public abstract class AbstractRegisterLocalModelStep extends AbstractRetryableWorkflowStep {
    private static final Logger logger = LogManager.getLogger(AbstractRegisterLocalModelStep.class);
    private final MachineLearningNodeClient mlClient;
    private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;

    /**
     * Instantiate this step.
     *
     * @param threadPool the thread pool, forwarded to the retryable base step
     * @param mlClient client used to submit the model registration request
     * @param flowFrameworkIndicesHandler handler used to record a deploy resource in the state index
     * @param flowFrameworkSettings settings, forwarded to the retryable base step
     */
    protected AbstractRegisterLocalModelStep(
        ThreadPool threadPool,
        MachineLearningNodeClient mlClient,
        FlowFrameworkIndicesHandler flowFrameworkIndicesHandler,
        FlowFrameworkSettings flowFrameworkSettings
    ) {
        super(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings);
        this.mlClient = mlClient;
        this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
    }

    /**
     * Registers a local model built from this node's inputs and completes the returned future once
     * the ML task finishes.
     *
     * <p>On success the future receives a {@link WorkflowData} containing the created resource id
     * (keyed by this step's resource name) and {@code register_model_status} (the task state name).
     * If {@code deploy} is {@code true}, the same id is additionally recorded as a
     * {@code deploy_model} resource in the state index before the future completes.
     *
     * @param currentNodeId id of the workflow node being executed
     * @param currentNodeInputs this node's own inputs
     * @param outputs outputs of previously executed nodes, keyed by node id
     * @param previousNodeInputs mapping of inputs this node takes from previous nodes
     * @param params additional workflow parameters
     * @return a future completed with the step output, or failed with a
     *     {@link WorkflowStepException}/{@link FlowFrameworkException}
     */
    @Override
    public PlainActionFuture<WorkflowData> execute(
        String currentNodeId,
        WorkflowData currentNodeInputs,
        Map<String, WorkflowData> outputs,
        Map<String, String> previousNodeInputs,
        Map<String, String> params
    ) {
        PlainActionFuture<WorkflowData> registerLocalModelFuture = PlainActionFuture.newFuture();
        try {
            Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
                getRequiredKeys(),
                getOptionalKeys(),
                currentNodeInputs,
                outputs,
                previousNodeInputs,
                params
            );
            String modelName = (String) inputs.get("name");
            String modelVersion = (String) inputs.get("version");
            String modelFormat = (String) inputs.get("model_format");
            String functionName = (String) inputs.get("function_name");
            String modelContentHashValue = (String) inputs.get("model_content_hash_value");
            String url = (String) inputs.get("url");
            String modelType = (String) inputs.get("model_type");
            String embeddingDimension = (String) inputs.get("embedding_dimension");
            String frameworkType = (String) inputs.get("framework_type");
            String description = (String) inputs.get("description");
            String modelGroupId = (String) inputs.get("model_group_id");
            String allConfig = (String) inputs.get("all_config");
            String modelInterface = (String) inputs.get("interface");
            Boolean deploy = ParseUtils.parseIfExists(inputs, "deploy", Boolean.class);

            MLRegisterModelInput.MLRegisterModelInputBuilder mlInputBuilder = MLRegisterModelInput.builder()
                .modelName(modelName)
                .version(modelVersion)
                .modelFormat(MLModelFormat.from(modelFormat));
            if (functionName != null) {
                mlInputBuilder.functionName(FunctionName.from(functionName));
            }
            if (modelContentHashValue != null) {
                mlInputBuilder.hashValue(modelContentHashValue);
            }
            if (url != null) {
                mlInputBuilder.url(url);
            }
            // A text-embedding model config is only attached when all three of its fields are present.
            // A non-numeric embedding_dimension raises NumberFormatException, which is an
            // IllegalArgumentException and is reported as BAD_REQUEST below.
            if (Stream.of(modelType, embeddingDimension, frameworkType).allMatch(x -> x != null)) {
                TextEmbeddingModelConfig.TextEmbeddingModelConfigBuilder mlConfigBuilder = TextEmbeddingModelConfig.builder()
                    .modelType(modelType)
                    .embeddingDimension(Integer.valueOf(embeddingDimension))
                    .frameworkType(TextEmbeddingModelConfig.FrameworkType.from(frameworkType));
                if (allConfig != null) {
                    mlConfigBuilder.allConfig(allConfig);
                }
                mlInputBuilder.modelConfig(mlConfigBuilder.build());
            }
            if (description != null) {
                mlInputBuilder.description(description);
            }
            if (modelGroupId != null) {
                mlInputBuilder.modelGroupId(modelGroupId);
            }
            if (modelInterface != null) {
                try {
                    // The interface arrives as a JSON string; parse it and flatten to a String->String map.
                    BytesArray modelInterfaceBytes = new BytesArray(modelInterface.getBytes(StandardCharsets.UTF_8));
                    Map<String, Object> modelInterfaceAsMap = XContentHelper.convertToMap(modelInterfaceBytes, false, MediaTypeRegistry.JSON)
                        .v2();
                    Map<String, String> parameters = ParseUtils.convertStringToObjectMapToStringToStringMap(modelInterfaceAsMap);
                    mlInputBuilder.modelInterface(parameters);
                } catch (Exception ex) {
                    String errorMessage = "Failed to create model interface";
                    logger.error(errorMessage, ex);
                    registerLocalModelFuture.onFailure(new WorkflowStepException(errorMessage, RestStatus.BAD_REQUEST));
                    // The step has already failed; do not go on to register a model as a side effect.
                    return registerLocalModelFuture;
                }
            }
            if (deploy != null) {
                mlInputBuilder.deployModel(deploy);
            }
            MLRegisterModelInput mlInput = mlInputBuilder.build();

            mlClient.register(mlInput, ActionListener.wrap(response -> {
                logger.info("Local Model registration task creation successful");
                String taskId = response.getTaskId();
                retryableGetMlTask(
                    currentNodeInputs.getWorkflowId(),
                    currentNodeId,
                    registerLocalModelFuture,
                    taskId,
                    "Local model registration",
                    ActionListener.wrap(mlTask -> {
                        String resourceName = WorkflowResources.getResourceByWorkflowStep(getName());
                        String id = getResourceId(mlTask);
                        WorkflowData stepOutput = new WorkflowData(
                            Map.ofEntries(Map.entry(resourceName, id), Map.entry("register_model_status", mlTask.getState().name())),
                            currentNodeInputs.getWorkflowId(),
                            currentNodeId
                        );
                        if (Boolean.TRUE.equals(deploy)) {
                            // Deploying during registration implies a deploy resource; record it in the state index too.
                            flowFrameworkIndicesHandler.updateResourceInStateIndex(
                                currentNodeInputs.getWorkflowId(),
                                currentNodeId,
                                "deploy_model",
                                id,
                                ActionListener.wrap(deployUpdateResponse -> {
                                    logger.info(
                                        "successfully updated resources created in state index: {}",
                                        deployUpdateResponse.getIndex()
                                    );
                                    registerLocalModelFuture.onResponse(stepOutput);
                                }, deployUpdateException -> {
                                    String errorMessage = "Failed to update simulated deploy step resource " + id;
                                    logger.error(errorMessage, deployUpdateException);
                                    registerLocalModelFuture.onFailure(
                                        new FlowFrameworkException(errorMessage, ExceptionsHelper.status(deployUpdateException))
                                    );
                                })
                            );
                        } else {
                            registerLocalModelFuture.onResponse(stepOutput);
                        }
                    }, registerLocalModelFuture::onFailure)
                );
            }, exception -> {
                // Only a "safe" exception's message is surfaced to the caller; otherwise use a generic one.
                Exception e = WorkflowStepException.getSafeException(exception);
                String errorMessage = e == null ? "Failed to register local model in step " + currentNodeId : e.getMessage();
                // Log the original exception so the root cause is kept server-side even when it is not surfaced.
                logger.error(errorMessage, exception);
                registerLocalModelFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e)));
            }));
        } catch (IllegalArgumentException iae) {
            registerLocalModelFuture.onFailure(new WorkflowStepException(iae.getMessage(), RestStatus.BAD_REQUEST));
        } catch (FlowFrameworkException e) {
            registerLocalModelFuture.onFailure(e);
        }
        return registerLocalModelFuture;
    }

    /**
     * Returns the input keys that must be present for this step.
     *
     * @return the set of required input keys
     */
    protected abstract Set<String> getRequiredKeys();

    /**
     * Returns the input keys this step will use when present.
     *
     * @return the set of optional input keys
     */
    protected abstract Set<String> getOptionalKeys();
}

