/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.plugin.transport;

import java.io.IOException;
import java.util.concurrent.ExecutionException;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.core.action.ActionListener;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.memory.NativeMemoryEntryContext;
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.opensearch.knn.plugin.transport.TrainingModelRequest;
import org.opensearch.knn.plugin.transport.TrainingModelResponse;
import org.opensearch.knn.training.TrainingJob;
import org.opensearch.knn.training.TrainingJobRunner;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

public class TrainingModelTransportAction
extends HandledTransportAction<TrainingModelRequest, TrainingModelResponse> {
    private final ClusterService clusterService;

    @Inject
    public TrainingModelTransportAction(TransportService transportService, ActionFilters actionFilters, ClusterService clusterService) {
        super("cluster:admin/knn_training_model_action", transportService, actionFilters, TrainingModelRequest::new);
        this.clusterService = clusterService;
    }

    protected void doExecute(Task task, TrainingModelRequest request, ActionListener<TrainingModelResponse> listener) {
        NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = new NativeMemoryEntryContext.TrainingDataEntryContext(request.getTrainingDataSizeInKB(), request.getTrainingIndex(), request.getTrainingField(), NativeMemoryLoadStrategy.TrainingLoadStrategy.getInstance(), this.clusterService, request.getMaximumVectorCount(), request.getSearchSize(), request.getVectorDataType());
        NativeMemoryEntryContext.AnonymousEntryContext modelAnonymousEntryContext = new NativeMemoryEntryContext.AnonymousEntryContext(request.getKnnMethodContext().estimateOverheadInKB(request.getDimension()), NativeMemoryLoadStrategy.AnonymousLoadStrategy.getInstance());
        TrainingJob trainingJob = new TrainingJob(request.getModelId(), request.getKnnMethodContext(), NativeMemoryCacheManager.getInstance(), trainingDataEntryContext, modelAnonymousEntryContext, request.getDimension(), request.getDescription(), this.clusterService.localNode().getEphemeralId(), request.getVectorDataType());
        KNNCounter.TRAINING_REQUESTS.increment();
        ActionListener wrappedListener = ActionListener.wrap(arg_0 -> listener.onResponse(arg_0), ex -> {
            KNNCounter.TRAINING_ERRORS.increment();
            listener.onFailure(ex);
        });
        try {
            TrainingJobRunner.getInstance().execute(trainingJob, (ActionListener<IndexResponse>)ActionListener.wrap(indexResponse -> wrappedListener.onResponse((Object)new TrainingModelResponse(indexResponse.getId())), arg_0 -> ((ActionListener)wrappedListener).onFailure(arg_0)));
        }
        catch (IOException | InterruptedException | ExecutionException e) {
            wrappedListener.onFailure(e);
        }
    }
}

