/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.classification;

import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.Classifier;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.StoredFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.queries.mlt.MoreLikeThis;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.util.BytesRef;

/**
 * A k-nearest-neighbor {@link Classifier} working on {@link BytesRef} class labels.
 *
 * <p>The input text is turned into a {@link MoreLikeThis} query per text field, the top
 * {@code k} matching documents that carry a value in the class field are retrieved, and each
 * class found in those documents is scored from its number of occurrences and the relative
 * scores of the documents it appeared in.
 */
public class KNearestNeighborClassifier implements Classifier<BytesRef> {
    /** Generates the "more like this" queries used to find the neighbors. */
    protected final MoreLikeThis mlt;

    /** Names of the text fields to match against; an entry may carry a boost as {@code name^boost}. */
    protected final String[] textFieldNames;

    /** Name of the stored field holding the class label(s) of a document. */
    protected final String classFieldName;

    /** Searcher used to run the neighbor query and to load stored fields. */
    protected final IndexSearcher indexSearcher;

    /** Number of neighbors (top documents) requested from the searcher. */
    protected final int k;

    /** Optional additional filter every neighbor must match; may be {@code null}. */
    protected final Query query;

    /**
     * Creates a classifier on top of the given index.
     *
     * @param indexReader reader over the index holding the already classified documents
     * @param similarity similarity used for searching; {@link BM25Similarity} is used when {@code null}
     * @param analyzer analyzer handed to {@link MoreLikeThis} for analyzing the input text
     * @param query optional query that every neighbor must match, or {@code null} for none
     * @param k number of top documents to retrieve
     * @param minDocsFreq passed to {@link MoreLikeThis#setMinDocFreq(int)} only when greater than 0
     * @param minTermFreq passed to {@link MoreLikeThis#setMinTermFreq(int)} only when greater than 0
     * @param classFieldName name of the field holding the class labels
     * @param textFieldNames names of the fields used to compare documents, optionally written as
     *     {@code name^boost}
     * @throws IOException declared for compatibility with callers
     */
    public KNearestNeighborClassifier(IndexReader indexReader, Similarity similarity, Analyzer analyzer, Query query, int k, int minDocsFreq, int minTermFreq, String classFieldName, String ... textFieldNames) throws IOException {
        this.textFieldNames = textFieldNames;
        this.classFieldName = classFieldName;
        this.mlt = new MoreLikeThis(indexReader);
        this.mlt.setAnalyzer(analyzer);
        this.mlt.setFieldNames(textFieldNames);
        this.indexSearcher = new IndexSearcher(indexReader);
        this.indexSearcher.setSimilarity(Objects.requireNonNullElseGet(similarity, BM25Similarity::new));
        // Non-positive frequencies mean "keep the MoreLikeThis defaults".
        if (minDocsFreq > 0) {
            this.mlt.setMinDocFreq(minDocsFreq);
        }
        if (minTermFreq > 0) {
            this.mlt.setMinTermFreq(minTermFreq);
        }
        this.query = query;
        this.k = k;
    }

    /**
     * Assigns the single best scoring class to the given text.
     *
     * @param text the text to classify
     * @return the best scoring class, or {@code null} when no class was found among the neighbors
     * @throws IOException if searching or reading stored fields fails
     */
    @Override
    public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
        return this.classifyFromTopDocs(this.knnSearch(text));
    }

    /**
     * Picks the class with the strictly highest score among those derived from the given neighbors.
     *
     * @param knnResults the neighbors found for the input
     * @return the best scoring class, or {@code null} when the neighbors yield no class
     * @throws IOException if reading stored fields fails
     */
    protected ClassificationResult<BytesRef> classifyFromTopDocs(TopDocs knnResults) throws IOException {
        List<ClassificationResult<BytesRef>> assignedClasses = this.buildListFromTopDocs(knnResults);
        ClassificationResult<BytesRef> assignedClass = null;
        double maxscore = -Double.MAX_VALUE;
        for (ClassificationResult<BytesRef> cl : assignedClasses) {
            // Strict comparison: the first class wins ties, NaN scores are never selected.
            if (cl.score() > maxscore) {
                assignedClass = cl;
                maxscore = cl.score();
            }
        }
        return assignedClass;
    }

    /**
     * Returns every class found among the neighbors of the given text, sorted by the natural
     * ordering of {@link ClassificationResult}.
     *
     * @param text the text to classify
     * @return the sorted list of candidate classes, possibly empty
     * @throws IOException if searching or reading stored fields fails
     */
    @Override
    public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
        TopDocs knnResults = this.knnSearch(text);
        List<ClassificationResult<BytesRef>> assignedClasses = this.buildListFromTopDocs(knnResults);
        Collections.sort(assignedClasses);
        return assignedClasses;
    }

    /**
     * Returns at most {@code max} classes found among the neighbors of the given text, sorted by
     * the natural ordering of {@link ClassificationResult}.
     *
     * @param text the text to classify
     * @param max the maximum number of classes to return; when fewer classes are available, all
     *     of them are returned
     * @return the first {@code max} entries of the sorted candidate list (or fewer)
     * @throws IOException if searching or reading stored fields fails
     * @throws IllegalArgumentException if {@code max} is negative
     */
    @Override
    public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
        TopDocs knnResults = this.knnSearch(text);
        List<ClassificationResult<BytesRef>> assignedClasses = this.buildListFromTopDocs(knnResults);
        Collections.sort(assignedClasses);
        // Clamp the upper bound: asking for more classes than were found must not throw.
        int limit = Math.min(max, assignedClasses.size());
        return assignedClasses.subList(0, limit);
    }

    /**
     * Builds and runs the neighbor query for the given text.
     *
     * <p>One optional "more like this" clause is added per text field; documents must have some
     * value in the class field and must match the extra {@link #query} when one was supplied.
     *
     * @param text the text to find neighbors for
     * @return the top {@link #k} matching documents
     * @throws IOException if analysis or searching fails
     */
    private TopDocs knnSearch(String text) throws IOException {
        BooleanQuery.Builder mltQuery = new BooleanQuery.Builder();
        for (String fieldName : this.textFieldNames) {
            String boost = null;
            this.mlt.setBoost(true);
            // A field given as "name^boost" is split into its name and its boost factor.
            if (fieldName.contains("^")) {
                String[] field2boost = fieldName.split("\\^");
                fieldName = field2boost[0];
                boost = field2boost[1];
            }
            if (boost != null) {
                this.mlt.setBoostFactor(Float.parseFloat(boost));
            }
            mltQuery.add(new BooleanClause(this.mlt.like(fieldName, new StringReader(text)), BooleanClause.Occur.SHOULD));
            // Reset so a boost never leaks into the next field's query.
            this.mlt.setBoostFactor(1.0f);
        }
        // Only documents that actually carry a class label are useful as neighbors.
        Query classFieldQuery = new WildcardQuery(new Term(this.classFieldName, "*"));
        mltQuery.add(new BooleanClause(classFieldQuery, BooleanClause.Occur.MUST));
        if (this.query != null) {
            mltQuery.add(this.query, BooleanClause.Occur.MUST);
        }
        return this.indexSearcher.search(mltQuery.build(), this.k);
    }

    /**
     * Turns the neighbors into a list of scored classes.
     *
     * <p>For every class value stored in a neighbor, the number of occurrences is counted and the
     * neighbor's score relative to the first hit's score is accumulated. A class's score is its
     * accumulated relative score divided by {@link #k}; when fewer than {@code k} class values
     * were collected in total, scores are rescaled by {@code k / collected}.
     *
     * @param topDocs the neighbors found for the input
     * @return the scored classes in no particular order, possibly empty
     * @throws IOException if reading stored fields fails
     */
    protected List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
        Map<BytesRef, Integer> classCounts = new HashMap<>();
        Map<BytesRef, Double> classBoosts = new HashMap<>();
        // The first hit's score is the reference every other hit is normalized against.
        float maxScore = topDocs.totalHits.value() == 0L ? Float.NaN : topDocs.scoreDocs[0].score;
        StoredFields storedFields = this.indexSearcher.storedFields();
        for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
            IndexableField[] classFields = storedFields.document(scoreDoc.doc).getFields(this.classFieldName);
            for (IndexableField singleStorableField : classFields) {
                if (singleStorableField == null) {
                    continue;
                }
                BytesRef cl = new BytesRef(singleStorableField.stringValue());
                classCounts.merge(cl, 1, Integer::sum);
                // Float division on purpose, widened afterwards (same precision as before).
                double singleBoost = scoreDoc.score / maxScore;
                classBoosts.merge(cl, singleBoost, Double::sum);
            }
        }
        List<ClassificationResult<BytesRef>> temporaryList = new ArrayList<>();
        int sumdoc = 0;
        for (Map.Entry<BytesRef, Integer> entry : classCounts.entrySet()) {
            int count = entry.getValue();
            double normBoost = classBoosts.get(entry.getKey()) / (double) count;
            temporaryList.add(new ClassificationResult<>(entry.getKey().clone(), (double) count * normBoost / (double) this.k));
            sumdoc += count;
        }
        if (sumdoc >= this.k) {
            return temporaryList;
        }
        // Fewer class values than k were collected: rescale as if k had been sumdoc.
        List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
        for (ClassificationResult<BytesRef> cr : temporaryList) {
            returnList.add(new ClassificationResult<>(cr.assignedClass(), cr.score() * (double) this.k / (double) sumdoc));
        }
        return returnList;
    }

    /** Returns a description listing the text fields, class field, k, filter query and similarity. */
    @Override
    public String toString() {
        return "KNearestNeighborClassifier{textFieldNames=" + Arrays.toString(this.textFieldNames) + ", classFieldName='" + this.classFieldName + "', k=" + this.k + ", query=" + String.valueOf(this.query) + ", similarity=" + String.valueOf(this.indexSearcher.getSimilarity()) + "}";
    }
}

