/*
 *    WekaDBWrapper.java
 *    Copyright (C) 2011 New Zealand Digital Library, http://www.nzdl.org
 *
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */
package org.greenstone.gsdl3.util;

import java.io.*;
import java.util.Vector;
import java.util.Collections;
import java.util.regex.Pattern;
import java.util.regex.Matcher;

import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;

import org.apache.log4j.*;

import org.greenstone.gsdl3.util.WekaFindInstanceKNN;

/** Java wrapper class providing access to Weka k-nearest-neighbour (KNN)
 *  matching over arousal/valence feature values.
 *  Devised (in the first instance) to operate as: java -jar weka.jar <arg1> <arg2>
 *
 * Inspired by MGSearchWrapper.java
 */

public class WekaDBWrapper
{
    /** Length, in seconds, of one arousal/valence (AV) segment.
     *  Not currently used by the query methods in this class. */
    public final static double AV_SEGMENT_LENGTH_SECS = 6.0;

    /** Splits a returned instance label such as "ds_22716_5743-6" into
     *  its <doc-id> (group 1) and <segment> (group 2) parts. */
    private static final Pattern DOC_SEG_RE = Pattern.compile("^(\\w+)-(\\d+)$");

    /** Rank penalty applied per result position in runQueryDiffAndMerge. */
    private static final double POS_PENALTY = 0.1;

    /** the query result (WekaDBDocInfo entries); filled in by runQuery / runQueryDiffAndMerge */
    protected Vector query_results_;

    /** segment number used to build the query label "<doc-id>-<offset_>" */
    protected int offset_ = 100;
    protected int length_ = 20; // currently unused by the query methods

    // Approximate matching not yet utilized
    protected double radius_;   // currently unused by the query methods

    /** number of results to return; also the k requested from the KNN search */
    protected int max_docs_;

    /** arousal value of the query instance */
    protected double arousal_;
    /** valence value of the query instance */
    protected double valence_;

    static Logger logger = Logger.getLogger (org.greenstone.gsdl3.util.WekaDBWrapper.class.getName ());

    public WekaDBWrapper() {
	query_results_ = null;
    }

    // query param methods

    /** Sets the segment number of the query track; the query instance is
	labelled "<doc-id>-<offset>". 100 by default.
	(Originally described as an offset into the track's feature-vector array,
	100 = 10 seconds assuming a 0.1 second frame size.) */
    public void setOffset(int offset) {
	offset_ = offset;
    }

    /** the number of consecutive frames used in match
	- 20 by default which equals 2 seconds (assuming 0.1 frame size).
	Not currently used by the query methods. */
    public void setLength(int length) {
	length_ = length;
    }

    /** distance used in approximate matching support - defaults to 0.0.
	Not currently used by the query methods. */
    public void setRadius(double radius) {
	radius_ = radius;
    }

    /** Sets the maximum number of documents returned by a query (the k in KNN). */
    public void setMaxDocs(int max_docs) {
	max_docs_ = max_docs;
    }

    /** Sets the arousal value of the query instance. */
    public void setArousal(double arousal) {
	arousal_ = arousal;
    }

    /** Sets the valence value of the query instance. */
    public void setValence(double valence) {
	valence_ = valence;
    }

    /** returns a string with all the current query param settings */
    // the following was in MG version, do we need this in WekaDB version? // ****
    //public String getQueryParams() {}


    /**
     * Appends one result entry to query_results_.
     *
     * @param first_entry if true, the entry is built from the full vectors;
     *                    otherwise only element 0 of each vector is used
     * @param doc_id      document identifier of the entry
     * @return always false, so the return value can be fed back in as
     *         first_entry for subsequent entries
     */
    protected boolean addQueryResult(boolean first_entry, String doc_id,
				     Vector<Double> arousalVector, Vector<Double> valenceVector,
				     Vector<Double> rankVector, Vector<Integer> offsetVector)
    {
	if (query_results_ == null) {
	    // the constructor leaves query_results_ null until a query has run
	    query_results_ = new Vector();
	}

	WekaDBDocInfo wekaDB_doc_info;
	if (first_entry) {
	    wekaDB_doc_info = new WekaDBDocInfo(doc_id,arousalVector,valenceVector,rankVector,offsetVector);
	}
	else {
	    // unbox explicitly so the (String,double,double,double,int) constructor is selected
	    double arousal = arousalVector.get(0);
	    double valence = valenceVector.get(0);
	    double rank    = rankVector.get(0);
	    int offset     = offsetVector.get(0);

	    wekaDB_doc_info = new WekaDBDocInfo(doc_id,arousal,valence,rank,offset);
	}

	query_results_.add(wekaDB_doc_info);
	return false;
    }


    /**
     * Merges new_doc_info into query_results by document id.
     * If an entry with the same doc id already exists, its top rank is
     * incremented by inc_rank_val; otherwise new_doc_info is appended.
     *
     * @return 1 if merged into an existing entry, 0 if appended
     */
    protected int mergeResultDoc(Vector query_results, WekaDBDocInfo new_doc_info, double inc_rank_val)
    {
	String new_doc_id = new_doc_info.getDocID();

	for (Object existing : query_results) {
	    WekaDBDocInfo existing_doc_info = (WekaDBDocInfo)existing;
	    if (new_doc_id.equals(existing_doc_info.getDocID())) {
		existing_doc_info.incTopRank(inc_rank_val);
		return 1;
	    }
	}

	query_results.add(new_doc_info);
	return 0;
    }

    /**
     * Loads the KNN model and asks Weka for the nearest neighbours of the
     * query instance "<doc_id>-<offset_>" with values (arousal_, valence_).
     *
     * Example returned result from Weka KNN
     * => first line is the input instance ('filename+segment',Arousal,Valence)
     *    following (indented lines) nearest neighbour matches in same format
     *
     *   ds_22716_5743-6,-0.549489,-0.118439
     *	   ds_22716_5743-6,-0.549489,-0.118439
     *	   ds_31008_6550-30,-0.549489,-0.118439
     *	   ds_72651_26831-6,-0.549489,-0.118439
     *	   ds_26196_9214-18,-0.549489,-0.118439
     *
     * @return the instances returned by WekaFindInstanceKNN.kNearestNeighbours
     *         (callers guard against null; NOTE(review): confirm whether null can occur)
     */
    private Instances findNearestInstances(String wekaDB_index_dir, String knn_model_file, String doc_id)
    {
	String full_knn_model_filename = wekaDB_index_dir + File.separatorChar + knn_model_file;
	logger.debug("**** full knn model filename  = " + full_knn_model_filename);

	WekaFindInstanceKNN.init(full_knn_model_filename);

	String query_doc_id_segment = doc_id + "-" + offset_;

	// NOTE(review): only max_docs_ neighbours are requested, although the callers
	// are written to sift through up to max_docs_*5 ("expanded") matches.
	// Confirm whether the expanded count should be requested here.
	return WekaFindInstanceKNN.kNearestNeighbours(query_doc_id_segment,arousal_,valence_,max_docs_);
    }

    /**
     * Similarity rank of a match: 1.0 minus a quarter of the city-block (L1)
     * distance between the query and matching (arousal, valence) pairs.
     * An exact match scores 1.0. (Presumably arousal/valence lie in [-1,1],
     * which would keep the rank in [0,1] — TODO confirm.)
     */
    private static double baseMatchRank(double query_arousal, double query_valence,
					double match_arousal, double match_valence)
    {
	double matching_diff = (Math.abs(query_arousal - match_arousal)
				+ Math.abs(query_valence - match_valence))/4.0;
	return 1.0 - matching_diff;
    }

    /** Replaces query_results_ with at most the first max_results entries of ranked_results. */
    private void storeTopResults(Vector ranked_results, int max_results)
    {
	query_results_ = new Vector();

	int num_results = Math.min(max_results, ranked_results.size());
	for (int i=0; i<num_results; i++) {
	    query_results_.add(ranked_results.get(i));
	}
    }

    /**
     * Actually carry out the query, using the values supplied through the set methods.
     * The query instance is "<query_string>-<offset_>" with values (arousal_, valence_).
     * Matches from the query document itself are skipped. The remaining matches
     * are kept in the order Weka returned them, each ranked by closeness in
     * arousal/valence. At most max_docs_ results are written to query_results_.
     *
     * @param wekaDB_index_dir directory containing the KNN model file
     * @param knn_model_file   model file name, joined to wekaDB_index_dir with File.separatorChar
     * @param assoc_index_dir  currently unused
     * @param query_string     document id of the query track
     */
    public void runQuery(String wekaDB_index_dir, String knn_model_file,
			 String assoc_index_dir, String query_string)
    {
	final String doc_id = query_string;
	final int k_nearest_num = max_docs_;
	// internally allow up to 5x more matches, then sift through to arrive at the best 'max_docs_'
	final int expanded_k_nearest_num = max_docs_ * 5;

	Instances nearest_instances = findNearestInstances(wekaDB_index_dir, knn_model_file, doc_id);
	if (nearest_instances == null) {
	    logger.error("No k-nearest neighbour result returned for query doc '" + doc_id + "'");
	    query_results_ = new Vector();
	    return;
	}

	// Clamp to the number of instances actually returned so we never
	// index past the end of nearest_instances.
	int num_to_process = Math.min(expanded_k_nearest_num, nearest_instances.size());

	if (num_to_process > k_nearest_num) {
	    logger.info("**** expanded number of k-nearest matches = " + num_to_process);
	}

	Vector expanded_query_results = new Vector();

	for (int ei=0; ei<num_to_process; ei++) {
	    Instance instance = nearest_instances.instance(ei);
	    String matching_doc_id_segment = instance.stringValue(0);

	    Matcher m = DOC_SEG_RE.matcher(matching_doc_id_segment);
	    if (!m.matches()) {
		logger.error("Returned AV k-nearest neighbour match '"+matching_doc_id_segment+"' could not be parsed as <doc-id>-<segment>" );
		continue;
	    }

	    String matching_doc_id = m.group(1);
	    int matching_segment_offset = Integer.parseInt(m.group(2));

	    if (matching_doc_id.equals(doc_id)) {
		// don't add in matches that come from a matching segment in the query doc
		logger.info("\tSelf-match with query doc => Skipping: " + instance);
		continue;
	    }

	    logger.info("\tAdding returned instance: " + instance);

	    double matching_arousal_val = instance.value(1);
	    double matching_valence_val = instance.value(2);
	    double matching_rank = baseMatchRank(arousal_, valence_,
						 matching_arousal_val, matching_valence_val);

	    expanded_query_results.add(new WekaDBDocInfo(matching_doc_id,
							 matching_arousal_val, matching_valence_val,
							 matching_rank, matching_segment_offset));
	}

	storeTopResults(expanded_query_results, k_nearest_num);
    }


    /**
     * Variant of runQuery that merges repeated matches from the same document.
     * Each match's rank is its arousal/valence closeness minus POS_PENALTY per
     * result position. A repeat match for an already-seen document adds a
     * diminishing share of its rank to that document instead of adding a new
     * entry. Results are sorted by WekaDBDocInfo's natural ordering. At most
     * max_docs_ results are written to query_results_.
     *
     * @param wekaDB_index_dir directory containing the KNN model file
     * @param knn_model_file   model file name, joined to wekaDB_index_dir with File.separatorChar
     * @param assoc_index_dir  currently unused
     * @param query_string     document id of the query track
     */
    public void runQueryDiffAndMerge(String wekaDB_index_dir, String knn_model_file,
				     String assoc_index_dir, String query_string)
    {
	final String doc_id = query_string;
	final int k_nearest_num = max_docs_;
	// internally allow up to 5x more matches, then sift through to arrive at the best 'max_docs_'
	final int expanded_k_nearest_num = max_docs_ * 5;

	Instances nearest_instances = findNearestInstances(wekaDB_index_dir, knn_model_file, doc_id);
	if (nearest_instances == null) {
	    logger.error("No k-nearest neighbour result returned for query doc '" + doc_id + "'");
	    query_results_ = new Vector();
	    return;
	}

	int num_to_process = Math.min(expanded_k_nearest_num, nearest_instances.size());

	Vector expanded_query_results = new Vector();
	int topup_count = 0;

	for (int ei=0; ei<num_to_process; ei++) {
	    Instance instance = nearest_instances.instance(ei);
	    logger.info("\tProcessing returned instance: " + instance);

	    String matching_doc_id_segment = instance.stringValue(0);

	    Matcher m = DOC_SEG_RE.matcher(matching_doc_id_segment);
	    if (!m.matches()) {
		logger.error("Returned AV k-nearest neighbour match '"+matching_doc_id_segment+"' could not be parsed as <doc-id>-<segment>" );
		continue;
	    }

	    String matching_doc_id = m.group(1);
	    int matching_segment_offset = Integer.parseInt(m.group(2));

	    if (matching_doc_id.equals(doc_id)) {
		// don't add in matches that come from a matching segment in the query doc
		continue;
	    }

	    double matching_arousal_val = instance.value(1);
	    double matching_valence_val = instance.value(2);

	    // later positions in the returned list are penalised linearly
	    double matching_rank = baseMatchRank(arousal_, valence_,
						 matching_arousal_val, matching_valence_val)
		- (POS_PENALTY * (double)ei);

	    logger.info("\tAdding in: matching_doc_id = " + matching_doc_id);
	    WekaDBDocInfo wekaDB_doc_info = new WekaDBDocInfo(matching_doc_id,
							      matching_arousal_val, matching_valence_val,
							      matching_rank, matching_segment_offset);

	    // a repeat match contributes 1/2 of its rank at first (topup_count == 0),
	    // then progressively less as more repeats are merged
	    double inc_rank_val = matching_rank / (double)(topup_count+2);
	    topup_count += mergeResultDoc(expanded_query_results, wekaDB_doc_info, inc_rank_val);

	    if ((expanded_query_results.size() > k_nearest_num) && (topup_count > k_nearest_num)) {
		// guard to stop multiple recurring matches in the same doc dominating the rank value
		break;
	    }
	}

	// ranking relies on WekaDBDocInfo's natural ordering (its compareTo)
	Collections.sort(expanded_query_results);

	storeTopResults(expanded_query_results, k_nearest_num);
    }


    /** Returns the results of the most recent query (a Vector of WekaDBDocInfo),
     *  or null if neither a query nor addQueryResult has run yet. */
    public Vector getQueryResult()
    {
	return query_results_;
    }
}

