/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.processor;

import com.jayway.jsonpath.Configuration;
import com.jayway.jsonpath.DocumentContext;
import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.Option;
import com.jayway.jsonpath.PathNotFoundException;
import com.jayway.jsonpath.Predicate;
import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.commons.text.StringSubstitutor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.support.GroupedActionListener;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.processor.InferenceProcessorAttributes;
import org.opensearch.ml.processor.ModelExecutor;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.transport.client.Client;

/**
 * Search request processor ({@code ml_inference}) that runs ML model predictions on values read from
 * the incoming query and writes the prediction results back into the search request.
 *
 * <p>Each entry of the combined {@code input_map} / {@code optional_input_map} list becomes one
 * prediction task. The entry's values are JSON paths into the query, and the values found there are
 * sent as model parameters. All tasks are dispatched asynchronously. Once every task has answered,
 * the matching {@code output_map} / {@code optional_output_map} entries are applied in one of two ways:
 * <ul>
 *   <li>directly to the original query, by writing the model output to each JSON path key;</li>
 *   <li>or, when {@code query_template} is configured, by substituting {@code ${...}} placeholders in
 *       the template.</li>
 * </ul>
 * The search request listener is completed exactly once.
 */
public class MLInferenceSearchRequestProcessor extends AbstractProcessor implements SearchRequestProcessor, ModelExecutor {
    private static final Logger logger = LogManager.getLogger(MLInferenceSearchRequestProcessor.class);

    public static final String TYPE = "ml_inference";
    public static final String IGNORE_MISSING = "ignore_missing";
    public static final String QUERY_TEMPLATE = "query_template";
    public static final String FUNCTION_NAME = "function_name";
    public static final String FULL_RESPONSE_PATH = "full_response_path";
    public static final String MODEL_INPUT = "model_input";
    public static final String DEFAULT_MODEl_INPUT = "{ \"parameters\": ${ml_inference.parameters} }";
    public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;
    public static final String OPTIONAL_INPUT_MAP = "optional_input_map";
    public static final String OPTIONAL_OUTPUT_MAP = "optional_output_map";

    private final NamedXContentRegistry xContentRegistry;
    private final InferenceProcessorAttributes inferenceProcessorAttributes;
    private final boolean ignoreMissing;
    private final String functionName;
    private final List<Map<String, String>> optionalInputMaps;
    private final List<Map<String, String>> optionalOutputMaps;
    private final String queryTemplate;
    private final boolean fullResponsePath;
    private final boolean ignoreFailure;
    private final String modelInput;
    // Held per instance: a static field assigned from the constructor would be overwritten by every
    // processor created afterwards.
    private final Client client;

    /**
     * Creates the processor. Instances are normally built through {@link Factory}, which validates the
     * configuration first.
     *
     * @param modelId id of the model to call
     * @param queryTemplate optional template used to build the rewritten query; {@code null} to edit
     *        the incoming query in place
     * @param inputMaps required input mappings (model input field -> query JSON path)
     * @param outputMaps required output mappings (query JSON path or template key -> model output field)
     * @param optionalInputMaps optional input mappings, combined with {@code inputMaps}
     * @param optionalOutputMaps optional output mappings, combined with {@code outputMaps}
     * @param modelConfigMaps static model configuration, also sent as model parameters
     * @param maxPredictionTask maximum number of prediction tasks allowed
     * @param tag processor tag
     * @param description processor description
     * @param ignoreMissing whether a missing required query field silently skips processing
     * @param functionName model function name
     * @param fullResponsePath whether output field names are paths into the full model response
     * @param ignoreFailure whether failures pass the original request through instead of failing
     * @param modelInput model input template
     * @param client client used to execute prediction tasks
     * @param xContentRegistry registry used to parse the rewritten query
     */
    protected MLInferenceSearchRequestProcessor(
        String modelId,
        String queryTemplate,
        List<Map<String, String>> inputMaps,
        List<Map<String, String>> outputMaps,
        List<Map<String, String>> optionalInputMaps,
        List<Map<String, String>> optionalOutputMaps,
        Map<String, String> modelConfigMaps,
        int maxPredictionTask,
        String tag,
        String description,
        boolean ignoreMissing,
        String functionName,
        boolean fullResponsePath,
        boolean ignoreFailure,
        String modelInput,
        Client client,
        NamedXContentRegistry xContentRegistry
    ) {
        super(tag, description, ignoreFailure);
        this.inferenceProcessorAttributes = new InferenceProcessorAttributes(modelId, inputMaps, outputMaps, modelConfigMaps, maxPredictionTask);
        this.optionalInputMaps = optionalInputMaps;
        this.optionalOutputMaps = optionalOutputMaps;
        this.ignoreMissing = ignoreMissing;
        this.functionName = functionName;
        this.fullResponsePath = fullResponsePath;
        this.queryTemplate = queryTemplate;
        this.ignoreFailure = ignoreFailure;
        this.modelInput = modelInput;
        this.client = client;
        this.xContentRegistry = xContentRegistry;
    }

    /**
     * Not supported: this processor only works asynchronously.
     *
     * @throws RuntimeException always
     */
    @Override
    public SearchRequest processRequest(SearchRequest request) throws Exception {
        throw new RuntimeException("ML inference search request processor make asynchronous calls and does not call processRequest");
    }

    /**
     * Rewrites {@code request} with model predictions and completes {@code requestListener} exactly once.
     *
     * @param request search request to rewrite; its source must not be {@code null}
     * @param requestContext pipeline context that receives the model outputs as attributes
     * @param requestListener completed with the (possibly rewritten) request, or with the failure when
     *        {@code ignoreFailure} is false
     */
    @Override
    public void processRequestAsync(SearchRequest request, PipelineProcessingContext requestContext, ActionListener<SearchRequest> requestListener) {
        try {
            if (request.source() == null) {
                throw new IllegalArgumentException("query body is empty, cannot processor inference on empty query request.");
            }
            setRequestContextFromExt(request, requestContext);
            String queryString = request.source().toString();
            rewriteQueryString(request, queryString, requestListener, requestContext);
        } catch (Exception e) {
            // Either pass the request through or report the failure; never both.
            if (ignoreFailure) {
                requestListener.onResponse(request);
            } else {
                requestListener.onFailure(e);
            }
        }
    }

    /**
     * Validates the query against the mappings, then dispatches one prediction task per combined input
     * mapping. Their results are collected and used to rewrite the request.
     *
     * @throws IOException if building a model inference request fails
     */
    private void rewriteQueryString(
        SearchRequest request,
        String queryString,
        ActionListener<SearchRequest> requestListener,
        PipelineProcessingContext requestContext
    ) throws IOException {
        List<Map<String, String>> processInputMap = inferenceProcessorAttributes.getInputMaps();
        List<Map<String, String>> processOutputMap = inferenceProcessorAttributes.getOutputMaps();
        List<Map<String, String>> combinedInputMaps = ModelExecutor.combineMaps(processInputMap, optionalInputMaps);
        List<Map<String, String>> combinedOutputMaps = ModelExecutor.combineMaps(processOutputMap, optionalOutputMaps);
        int combinedInputMapSize = combinedInputMaps == null ? 0 : combinedInputMaps.size();
        if (combinedInputMapSize == 0) {
            requestListener.onResponse(request);
            return;
        }
        try {
            if (!validateQueryFieldInQueryString(processInputMap, processOutputMap, queryString, ignoreMissing)) {
                // A required field is missing and ignore_missing is set: pass the request through untouched
                // and stop, so no prediction runs and the listener is not completed a second time.
                requestListener.onResponse(request);
                return;
            }
        } catch (Exception e) {
            if (ignoreFailure) {
                requestListener.onResponse(request);
            } else {
                requestListener.onFailure(e);
            }
            return;
        }
        ActionListener<Map<Integer, MLOutput>> rewriteRequestListener = createRewriteRequestListener(
            request,
            queryString,
            requestListener,
            combinedOutputMaps,
            requestContext
        );
        GroupedActionListener<Map<Integer, MLOutput>> batchPredictionListener = createBatchPredictionListener(
            rewriteRequestListener,
            combinedInputMapSize
        );
        for (int inputMapIndex = 0; inputMapIndex < combinedInputMapSize; inputMapIndex++) {
            processPredictions(queryString, combinedInputMaps, inputMapIndex, batchPredictionListener);
        }
    }

    /**
     * Builds the listener that receives every prediction result, keyed by mapping index. It applies all
     * output mappings to a single rewritten query and completes {@code requestListener} once.
     */
    private ActionListener<Map<Integer, MLOutput>> createRewriteRequestListener(
        final SearchRequest request,
        final String queryString,
        final ActionListener<SearchRequest> requestListener,
        final List<Map<String, String>> processOutputMap,
        final PipelineProcessingContext requestContext
    ) {
        return new ActionListener<Map<Integer, MLOutput>>() {

            @Override
            public void onResponse(Map<Integer, MLOutput> multipleMLOutputs) {
                if (multipleMLOutputs != null && !multipleMLOutputs.isEmpty()) {
                    try {
                        String newQueryString = queryTemplate == null
                            ? rewriteIncomeQuery(multipleMLOutputs)
                            : rewriteQueryTemplate(multipleMLOutputs);
                        request.source(getSearchSourceBuilder(xContentRegistry, newQueryString));
                    } catch (Exception e) {
                        if (!ignoreFailure) {
                            requestListener.onFailure(e);
                            return;
                        }
                        logger.error("Failed in writing prediction outcomes to new query", e);
                    }
                }
                // Completed outside the try so an exception raised downstream cannot trigger a second
                // completion through onFailure.
                requestListener.onResponse(request);
            }

            @Override
            public void onFailure(Exception e) {
                if (ignoreFailure) {
                    logger.error("Failed in writing prediction outcomes to new query", e);
                    requestListener.onResponse(request);
                } else {
                    requestListener.onFailure(e);
                }
            }

            /** Applies every output mapping to one copy of the incoming query and serializes it to JSON. */
            private String rewriteIncomeQuery(Map<Integer, MLOutput> multipleMLOutputs) {
                Object incomeQueryObject = JsonPath.parse(queryString).read("$");
                for (Map.Entry<Integer, MLOutput> entry : multipleMLOutputs.entrySet()) {
                    Map<String, String> outputMapping = processOutputMap.get(entry.getKey());
                    incomeQueryObject = updateIncomeQueryObject(incomeQueryObject, outputMapping, entry.getValue());
                }
                return StringUtils.toJson(incomeQueryObject);
            }

            /**
             * Writes each model output to its query path and records it as a request-context attribute.
             * For {@code ext.} / {@code $.ext.} paths the intermediate objects are created first. Other
             * paths are only written when they already exist in the query.
             *
             * @return the updated query object
             * @throws IllegalArgumentException if a path cannot be resolved while writing
             */
            private Object updateIncomeQueryObject(Object incomeQueryObject, Map<String, String> outputMapping, MLOutput mlOutput) {
                Object updatedQuery = incomeQueryObject;
                for (Map.Entry<String, String> outputMapEntry : outputMapping.entrySet()) {
                    String newQueryField = outputMapEntry.getKey();
                    try {
                        Object modelOutputValue = MLInferenceSearchRequestProcessor.this
                            .getModelOutputValue(mlOutput, outputMapEntry.getValue(), ignoreMissing, fullResponsePath);
                        requestContext.setAttribute(newQueryField, modelOutputValue);
                        if (newQueryField.startsWith("$.ext.") || newQueryField.startsWith("ext.")) {
                            updatedQuery = StringUtils.prepareNestedStructures(updatedQuery, newQueryField);
                        }
                        if (StringUtils.pathExists(updatedQuery, newQueryField)) {
                            JsonPath.using(ModelExecutor.suppressExceptionConfiguration)
                                .parse(updatedQuery)
                                .set(newQueryField, modelOutputValue);
                        }
                    } catch (PathNotFoundException e) {
                        logger.error("Failed to set {} in query string: {}", newQueryField, e.getMessage(), e);
                        throw new IllegalArgumentException("can not find path " + newQueryField + " in query string", e);
                    }
                }
                return updatedQuery;
            }

            /**
             * Substitutes all mapped model outputs into the query template in one pass. Map-valued
             * outputs are inserted as JSON.
             */
            private String rewriteQueryTemplate(Map<Integer, MLOutput> multipleMLOutputs) {
                Map<String, Object> valuesMap = new HashMap<>();
                for (Map.Entry<Integer, MLOutput> entry : multipleMLOutputs.entrySet()) {
                    Map<String, String> outputMapping = processOutputMap.get(entry.getKey());
                    for (Map.Entry<String, String> outputMapEntry : outputMapping.entrySet()) {
                        Object modelOutputValue = MLInferenceSearchRequestProcessor.this
                            .getModelOutputValue(entry.getValue(), outputMapEntry.getValue(), ignoreMissing, fullResponsePath);
                        if (modelOutputValue instanceof Map) {
                            modelOutputValue = StringUtils.toJson(modelOutputValue);
                        }
                        valuesMap.put(outputMapEntry.getKey(), modelOutputValue);
                    }
                }
                return new StringSubstitutor(valuesMap).replace(queryTemplate);
            }
        };
    }

    /**
     * Builds a listener that waits for {@code max(inputMapSize, 1)} prediction results. It merges them
     * into a single index -> output map before forwarding, and forwards failures as-is.
     */
    private GroupedActionListener<Map<Integer, MLOutput>> createBatchPredictionListener(
        final ActionListener<Map<Integer, MLOutput>> rewriteRequestListener,
        int inputMapSize
    ) {
        ActionListener<Collection<Map<Integer, MLOutput>>> mergingListener = new ActionListener<Collection<Map<Integer, MLOutput>>>() {

            @Override
            public void onResponse(Collection<Map<Integer, MLOutput>> mlOutputMapCollection) {
                Map<Integer, MLOutput> mlOutputMaps = new HashMap<>();
                for (Map<Integer, MLOutput> mlOutputMap : mlOutputMapCollection) {
                    mlOutputMaps.putAll(mlOutputMap);
                }
                rewriteRequestListener.onResponse(mlOutputMaps);
            }

            @Override
            public void onFailure(Exception e) {
                logger.error("Prediction Failed:", e);
                rewriteRequestListener.onFailure(e);
            }
        };
        return new GroupedActionListener<>(mergingListener, Math.max(inputMapSize, 1));
    }

    /**
     * Checks that every query path referenced by the input mappings resolves to a non-null value.
     *
     * @return {@code false} when a path is missing and {@code ignoreMissing} is true
     * @throws IllegalArgumentException when a path is missing and {@code ignoreMissing} is false
     */
    private boolean validateRequiredInputMappingFields(List<Map<String, String>> processInputMap, String queryString, boolean ignoreMissing) {
        DocumentContext jsonData = parseSuppressingExceptions(queryString);
        for (Map<String, String> inputMap : processInputMap) {
            for (String queryField : inputMap.values()) {
                if (jsonData.read(queryField) == null) {
                    if (!ignoreMissing) {
                        throw new IllegalArgumentException("cannot find field: " + queryField + " in query string: " + jsonData.jsonString());
                    }
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * When no query template is configured, checks that every output-mapping target under
     * {@code query.} / {@code $.query.} already exists in the query. Other targets are not checked.
     *
     * @return {@code false} when a target is missing and {@code ignoreMissing} is true
     * @throws IllegalArgumentException when a target is missing and {@code ignoreMissing} is false
     */
    private boolean validateRequiredOutputMappingFields(List<Map<String, String>> processOutputMap, String queryString, boolean ignoreMissing) {
        DocumentContext jsonData = parseSuppressingExceptions(queryString);
        if (queryTemplate != null) {
            return true;
        }
        for (Map<String, String> outputMap : processOutputMap) {
            for (String queryField : outputMap.keySet()) {
                boolean targetsQuery = queryField.startsWith("query.") || queryField.startsWith("$.query.");
                if (targetsQuery && jsonData.read(queryField) == null) {
                    if (!ignoreMissing) {
                        throw new IllegalArgumentException("cannot find field: " + queryField + " in query string: " + jsonData.jsonString());
                    }
                    return false;
                }
            }
        }
        return true;
    }

    /** Parses {@code json} with a configuration that returns {@code null} for unresolved paths instead of throwing. */
    private static DocumentContext parseSuppressingExceptions(String json) {
        Configuration suppressExceptionConfiguration = Configuration.defaultConfiguration().addOptions(Option.SUPPRESS_EXCEPTIONS);
        return JsonPath.using(suppressExceptionConfiguration).parse(json);
    }

    /** Runs the required input and output mapping checks; empty mapping lists pass trivially. */
    private boolean validateQueryFieldInQueryString(
        List<Map<String, String>> processInputMap,
        List<Map<String, String>> processOutputMap,
        String queryString,
        boolean ignoreMissing
    ) {
        if (!CollectionUtils.isEmpty(processInputMap) && !validateRequiredInputMappingFields(processInputMap, queryString, ignoreMissing)) {
            return false;
        }
        return CollectionUtils.isEmpty(processOutputMap) || validateRequiredOutputMappingFields(processOutputMap, queryString, ignoreMissing);
    }

    /**
     * Builds and dispatches the prediction task for input mapping {@code inputMapIndex}.
     *
     * <p>Model parameters start from the model config. Each mapped query value that exists in the
     * query is added as JSON, keyed by the model input field name. Only parameters whose keys are not
     * model config keys are passed as input mappings. The result is reported to
     * {@code batchPredictionListener} as a single-entry map keyed by {@code inputMapIndex}.
     *
     * @throws IOException if the inference request cannot be built
     */
    private void processPredictions(
        String queryString,
        List<Map<String, String>> processInputMap,
        final int inputMapIndex,
        final GroupedActionListener<Map<Integer, MLOutput>> batchPredictionListener
    ) throws IOException {
        Map<String, String> modelParameters = new HashMap<>();
        Map<String, String> modelConfigs = new HashMap<>();
        Map<String, String> modelConfigMaps = inferenceProcessorAttributes.getModelConfigMaps();
        if (modelConfigMaps != null) {
            modelParameters.putAll(modelConfigMaps);
            modelConfigs.putAll(modelConfigMaps);
        }
        if (processInputMap != null) {
            Map<String, String> inputMapping = processInputMap.get(inputMapIndex);
            Object newQuery = JsonPath.parse(queryString).read("$");
            for (Map.Entry<String, String> entry : inputMapping.entrySet()) {
                String queryFieldName = entry.getValue();
                if (!hasField(newQuery, queryFieldName)) {
                    continue;
                }
                Object queryFieldRawValue = JsonPath.parse(newQuery).read(queryFieldName);
                modelParameters.put(entry.getKey(), StringUtils.toJson(queryFieldRawValue));
            }
        }
        Map<String, String> inputMappings = new HashMap<>();
        for (Map.Entry<String, String> parameter : modelParameters.entrySet()) {
            if (!modelConfigs.containsKey(parameter.getKey())) {
                inputMappings.put(parameter.getKey(), parameter.getValue());
            }
        }
        ActionRequest actionRequest = getMLModelInferenceRequest(
            xContentRegistry,
            modelParameters,
            modelConfigs,
            inputMappings,
            inferenceProcessorAttributes.getModelId(),
            functionName,
            modelInput
        );
        client.execute(MLPredictionTaskAction.INSTANCE, actionRequest, new ActionListener<MLTaskResponse>() {

            @Override
            public void onResponse(MLTaskResponse mlTaskResponse) {
                Map<Integer, MLOutput> mlOutputMap = new HashMap<>();
                mlOutputMap.put(inputMapIndex, mlTaskResponse.getOutput());
                batchPredictionListener.onResponse(mlOutputMap);
            }

            @Override
            public void onFailure(Exception e) {
                batchPredictionListener.onFailure(e);
            }
        });
    }

    /**
     * Parses a JSON search body into a {@link SearchSourceBuilder}. The parser is closed afterwards.
     *
     * @throws IOException if the body is not valid JSON or does not start with an object
     */
    private static SearchSourceBuilder getSearchSourceBuilder(NamedXContentRegistry xContentRegistry, String queryString) throws IOException {
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
        try (
            XContentParser queryParser = XContentType.JSON.xContent()
                .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryString)
        ) {
            XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, queryParser.nextToken(), queryParser);
            searchSourceBuilder.parseXContent(queryParser);
        }
        return searchSourceBuilder;
    }

    @Override
    public String getType() {
        return TYPE;
    }

    @Generated
    public InferenceProcessorAttributes getInferenceProcessorAttributes() {
        return this.inferenceProcessorAttributes;
    }

    @Generated
    public List<Map<String, String>> getOptionalInputMaps() {
        return this.optionalInputMaps;
    }

    @Generated
    public List<Map<String, String>> getOptionalOutputMaps() {
        return this.optionalOutputMaps;
    }

    /** Factory that reads and validates the processor configuration before building the processor. */
    public static class Factory implements Processor.Factory<SearchRequestProcessor> {
        private final Client client;
        private final NamedXContentRegistry xContentRegistry;

        public Factory(Client client, NamedXContentRegistry xContentRegistry) {
            this.client = client;
            this.xContentRegistry = xContentRegistry;
        }

        /**
         * Creates a processor from {@code config}.
         *
         * @throws IllegalArgumentException if the input or output mappings are missing or their counts
         *         differ, if the input mappings exceed {@code max_prediction_tasks}, or if a non-remote
         *         model has no {@code model_input}
         */
        @Override
        public MLInferenceSearchRequestProcessor create(
            Map<String, Processor.Factory<SearchRequestProcessor>> processorFactories,
            String processorTag,
            String description,
            boolean ignoreFailure,
            Map<String, Object> config,
            Processor.PipelineContext pipelineContext
        ) {
            String modelId = ConfigurationUtils.readStringProperty(TYPE, processorTag, config, "model_id");
            String queryTemplate = ConfigurationUtils.readOptionalStringProperty(TYPE, processorTag, config, QUERY_TEMPLATE);
            Map<String, Object> modelConfigInput = ConfigurationUtils.readOptionalMap(TYPE, processorTag, config, "model_config");
            List<Map<String, String>> inputMaps = ConfigurationUtils.readOptionalList(TYPE, processorTag, config, "input_map");
            List<Map<String, String>> outputMaps = ConfigurationUtils.readOptionalList(TYPE, processorTag, config, "output_map");
            List<Map<String, String>> optionalInputMaps = ConfigurationUtils.readOptionalList(TYPE, processorTag, config, OPTIONAL_INPUT_MAP);
            List<Map<String, String>> optionalOutputMaps = ConfigurationUtils.readOptionalList(TYPE, processorTag, config, OPTIONAL_OUTPUT_MAP);
            if (CollectionUtils.isEmpty(inputMaps) && CollectionUtils.isEmpty(optionalInputMaps)) {
                throw new IllegalArgumentException("Please provide at least one non-empty input_map or optional_input_map for ML Inference Search Request Processor");
            }
            if (CollectionUtils.isEmpty(outputMaps) && CollectionUtils.isEmpty(optionalOutputMaps)) {
                throw new IllegalArgumentException("Please provide at least one non-empty output_map or optional_output_map for ML Inference Search Request Processor");
            }
            int maxPredictionTask = ConfigurationUtils.readIntProperty(TYPE, processorTag, config, "max_prediction_tasks", DEFAULT_MAX_PREDICTION_TASKS);
            boolean ignoreMissing = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, IGNORE_MISSING, false);
            String functionName = ConfigurationUtils.readStringProperty(TYPE, processorTag, config, FUNCTION_NAME, FunctionName.REMOTE.name());
            String modelInput = ConfigurationUtils.readOptionalStringProperty(TYPE, processorTag, config, MODEL_INPUT);
            boolean isRemote = functionName.equalsIgnoreCase(FunctionName.REMOTE.name());
            if (isRemote) {
                modelInput = modelInput != null ? modelInput : DEFAULT_MODEl_INPUT;
            } else if (modelInput == null) {
                throw new IllegalArgumentException("Please provide model input when using a local model in ML Inference Processor");
            }
            // Local models default to full-response paths; remote models default to the inference results.
            boolean fullResponsePath = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, !isRemote);
            Map<String, String> modelConfigMaps = modelConfigInput == null ? null : StringUtils.getParameterMap(modelConfigInput);
            List<Map<String, String>> combinedInputMaps = ModelExecutor.combineMaps(inputMaps, optionalInputMaps);
            List<Map<String, String>> combinedOutputMaps = ModelExecutor.combineMaps(outputMaps, optionalOutputMaps);
            if (combinedInputMaps != null && combinedInputMaps.size() > maxPredictionTask) {
                throw new IllegalArgumentException("The number of prediction task setting in this process is " + combinedInputMaps.size() + ". It exceeds the max_prediction_tasks of " + maxPredictionTask + ". Please reduce the size of input_map or optional_input_map or increase max_prediction_tasks.");
            }
            if (combinedOutputMaps != null && combinedInputMaps != null && combinedOutputMaps.size() != combinedInputMaps.size()) {
                throw new IllegalArgumentException("when output_maps/optional_output_maps and input_maps/optional_input_maps are provided, their length needs to match. The input is in length of " + combinedInputMaps.size() + ", while output_maps is in the length of " + combinedOutputMaps.size() + ". Please adjust mappings.");
            }
            return new MLInferenceSearchRequestProcessor(
                modelId,
                queryTemplate,
                inputMaps,
                outputMaps,
                optionalInputMaps,
                optionalOutputMaps,
                modelConfigMaps,
                maxPredictionTask,
                processorTag,
                description,
                ignoreMissing,
                functionName,
                fullResponsePath,
                ignoreFailure,
                modelInput,
                this.client,
                this.xContentRegistry
            );
        }
    }
}

