Skip to content

Commit

Permalink
added support for custom weighting scheme per field
Browse files Browse the repository at this point in the history
  • Loading branch information
jnioche committed Aug 26, 2010
1 parent fc0ab27 commit e266857
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 21 deletions.
3 changes: 3 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
1.4.1-SNAPSHOT
- can specify weighting schemes per field
- added TestWeightingSchemes
1.4
- added XMLCorpusReader and XMLCorpusClassifier which generate a raw file from an XML corpus or classify an XML corpus using an existing model
- added java implementation of liblinear
Expand Down
2 changes: 2 additions & 0 deletions src/com/digitalpebble/classification/Field.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package com.digitalpebble.classification;

/** Field name and its content **/

public class Field {

String _name = null;
Expand Down
52 changes: 47 additions & 5 deletions src/com/digitalpebble/classification/Lexicon.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ public class Lexicon {
/** list of fields used by a corpus * */
private Map<String, Integer> fields = new HashMap<String, Integer>();

/** Custom weighting schemes for fields **/
private Map<String, WeightingMethod> customWeights = new HashMap<String, WeightingMethod>();

private int lastFieldId = -1;

private AttributeScorer filter;
Expand All @@ -83,14 +86,40 @@ public Lexicon(String file) throws IOException {
this.loadFromFile(file);
}

/**
 * Looks up the weighting scheme to apply to a given field.
 *
 * @param fieldName name of the field to look up
 * @return the scheme registered for {@code fieldName} in the custom
 *         weights map, or the default scheme when the map holds no
 *         entry for that field
 **/
public WeightingMethod getMethod(String fieldName) {
    WeightingMethod custom = customWeights.get(fieldName);
    // no field-specific entry : fall back to the default scheme
    return (custom != null) ? custom : method_used;
}

/**
 * Gives access to the default weighting scheme.
 *
 * @return the scheme currently stored as the default one
 **/
public WeightingMethod getMethod() {
    return method_used;
}

/**
 * Replaces the default weighting scheme.
 *
 * @param method scheme to store as the new default
 **/
public void setMethod(WeightingMethod method) {
    method_used = method;
}

/**
 * Sets the weighting scheme for a specific field. The scheme of a field can
 * be set only once: setting the same scheme again is accepted and changes
 * nothing, setting a different one is rejected.
 *
 * @param method scheme to use for the field, must not be null
 * @param fieldName name of the field the scheme applies to
 * @throws IllegalArgumentException if {@code method} is null
 * @throws IllegalStateException if a different scheme has already been set
 *         for {@code fieldName}
 **/
public void setMethod(WeightingMethod method, String fieldName) {
    // reject null up front : it used to be stored silently on the first
    // call for a field but caused a NullPointerException on a later one
    if (method == null)
        throw new IllegalArgumentException(
                "Weighting method must not be null for field " + fieldName);
    WeightingMethod existingmethod = this.customWeights.get(fieldName);
    if (existingmethod == null) {
        this.customWeights.put(fieldName, method);
        return;
    }
    // already one specified : check that it is the same as the one we have
    if (!method.equals(existingmethod))
        throw new IllegalStateException("Already set weight of field "
                + fieldName + " to " + existingmethod.toString());
}

/**
 * Returns the current value of the {@code docNum} field.
 * NOTE(review): presumably the number of documents seen by this lexicon -
 * confirm against the code updating it.
 **/
public int getDocNum() {
    return docNum;
}
Expand Down Expand Up @@ -214,7 +243,7 @@ public void applyAttributeFilter(AttributeScorer filter, int rank) {

// creates an entry for the token
// called from Document
public int createIndex(String tokenForm) {
public int createIndex(String tokenForm) {
int[] index = (int[]) tokenForm2index.get(tokenForm);
if (index == null) {
index = new int[] { nextAttributeID };
Expand Down Expand Up @@ -245,13 +274,21 @@ private void loadFromFile(String filename) throws IOException {
this.labels = Arrays.asList(reader.readLine().split(" "));
String[] tmp = reader.readLine().split(" ");
for (String f : tmp) {
getFieldID(f, true);
// see if there is a custom weight for it
String[] fieldTokens = f.split(":");
String field_name = fieldTokens[0];
if (fieldTokens.length > 1) {
WeightingMethod method = Parameters.WeightingMethod
.methodFromString(fieldTokens[1]);
customWeights.put(field_name, method);
}
getFieldID(field_name, true);
}
int loaded = 0;
int highestID = 0;
Pattern tab = Pattern.compile("\t");
while ((line = reader.readLine()) != null) {
String[] content_pos = tab.split(line);
String[] content_pos = tab.split(line);
int index = Integer.parseInt(content_pos[1]);
if (index > highestID)
highestID = index;
Expand Down Expand Up @@ -284,11 +321,16 @@ public void saveToFile(String filename) throws IOException {
writer.write((String) labelIters.next() + " ");
}
writer.write("\n");
// save the field names (if available)
// save the field names (possibly with non default scheme)
for (String fname : this.getFields()) {
writer.write(fname + " ");
writer.write(fname);
WeightingMethod method = customWeights.get(fname);
if (method != null)
writer.write(":" + method.name());
writer.write(" ");
}
writer.write("\n");

// dump all token_forms one by one
while (forms.hasNext()) {
String key = (String) forms.next();
Expand Down
17 changes: 13 additions & 4 deletions src/com/digitalpebble/classification/MultiFieldDocument.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import java.util.Iterator;
import java.util.TreeMap;

import com.digitalpebble.classification.Parameters.WeightingMethod;

public class MultiFieldDocument implements Document {
int label = 0;

Expand Down Expand Up @@ -227,8 +229,8 @@ public Vector getFeatureVector(Lexicon lexicon,
if (indices[pos] == Integer.MAX_VALUE) {
break;
}
if (lexicon.getDocFreq(indices[pos]) <= 0) continue;
double score = getScore(pos, lexicon, method, numDocs);
if (lexicon.getDocFreq(indices[pos]) <= 0) continue;
double score = getScore(pos, lexicon, numDocs);
// removed in meantime?
if (score == 0) continue;
copyvalues[pos] = score;
Expand All @@ -246,15 +248,22 @@ public Vector getFeatureVector(Lexicon lexicon,
return new Vector(trimmedindices, trimmedvalues);
}

private double getScore(int pos, Lexicon lexicon,
Parameters.WeightingMethod method, double numdocs) {
/**
* Returns the score of an attribute given the weighting scheme
* specified in the lexicon or for a specific field
**/
private double getScore(int pos, Lexicon lexicon, double numdocs) {
double score = 0;
int indexTerm = this.indices[pos];
double occurences = (double) this.freqs[pos];

int fieldNum = this.indexToField[pos];
double frequency = occurences / tokensPerField[fieldNum];

// is there a custom weight for this field?
String fieldName = lexicon.getFields()[fieldNum];
WeightingMethod method = lexicon.getMethod(fieldName);

if (method.equals(Parameters.WeightingMethod.BOOLEAN)) {
score = 1;
} else if (method.equals(Parameters.WeightingMethod.OCCURRENCES)) {
Expand Down
57 changes: 45 additions & 12 deletions src/com/digitalpebble/classification/test/TestMultiFieldDocs.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,43 +16,53 @@

package com.digitalpebble.classification.test;

import java.util.Map;

import com.digitalpebble.classification.Document;
import com.digitalpebble.classification.Field;
import com.digitalpebble.classification.Parameters;
import com.digitalpebble.classification.RAMTrainingCorpus;
import com.digitalpebble.classification.TextClassifier;
import com.digitalpebble.classification.Vector;

public class TestMultiFieldDocs extends AbstractLearnerTest {

public void testMultiField() throws Exception {
Field[] fields = new Field[3];
fields[0] = new Field("title", new String[]{"This","is","a","title"});
fields[1] = new Field("abstract", new String[]{"abstract"});
fields[2] = new Field("content", new String[]{"This","is","the","content","this","will","have","a","large","value"});
fields[0] = new Field("title", new String[] { "This", "is", "a",
"title" });
fields[1] = new Field("abstract", new String[] { "abstract" });
fields[2] = new Field("content", new String[] { "This", "is", "the",
"content", "this", "will", "have", "a", "large", "value" });
learner.setMethod(Parameters.WeightingMethod.TFIDF);
Document doc = learner.createDocument(fields, "large");

Field[] fields2 = new Field[2];
fields2[0] = new Field("title", new String[]{"This","is","not","a","title"});
fields2[0] = new Field("title", new String[] { "This", "is", "not",
"a", "title" });
// fields2[1] = new Field("abstract", new String[]{"abstract"});
fields2[1] = new Field("content", new String[]{"This","is","the","content","this","will","have","a","small","value"});
fields2[1] = new Field("content", new String[] { "This", "is", "the",
"content", "this", "will", "have", "a", "small", "value" });
learner.setMethod(Parameters.WeightingMethod.TFIDF);
Document doc2 = learner.createDocument(fields2, "small");

// try putting the same field several times
Field[] fields3 = new Field[3];
fields3[0] = new Field("title", new String[]{"This","is","not","a","title"});
fields3[0] = new Field("title", new String[] { "This", "is", "not",
"a", "title" });
// fields2[1] = new Field("abstract", new String[]{"abstract"});
fields3[1] = new Field("content", new String[]{"This","is","the","content","this","will","have","a","small","value"});
fields3[2] = new Field("title", new String[]{"some","different","content"});
fields3[1] = new Field("content", new String[] { "This", "is", "the",
"content", "this", "will", "have", "a", "small", "value" });
fields3[2] = new Field("title", new String[] { "some", "different",
"content" });
learner.setMethod(Parameters.WeightingMethod.TFIDF);
Document doc3 = learner.createDocument(fields3, "small");

RAMTrainingCorpus corpus = new RAMTrainingCorpus();
corpus.add(doc);
corpus.add(doc2);
learner.learn(corpus);

TextClassifier classi = TextClassifier.getClassifier(tempFile);
double[] scores = classi.classify(doc);
assertEquals("large", classi.getBestLabel(scores));
Expand All @@ -62,4 +72,27 @@ public void testMultiField() throws Exception {
assertEquals("small", classi.getBestLabel(scores));
}

/**
 * Checks that a weighting scheme set for a single field takes precedence
 * over the default one: FREQUENCY is the default, BOOLEAN is set for the
 * "keywords" field, and every value of the resulting feature vector is
 * expected to be 1.0.
 **/
public void testCustomWeightingScheme() throws Exception {
    Field[] keywordFields = { new Field("keywords", new String[] { "test",
            "keywords" }) };

    // default scheme first, then the override for the "keywords" field
    learner.setMethod(Parameters.WeightingMethod.FREQUENCY);
    learner.getLexicon().setMethod(Parameters.WeightingMethod.BOOLEAN,
            "keywords");

    Document document = learner.createDocument(keywordFields, "large");
    Vector features = document.getFeatureVector(learner.getLexicon());
    int[] attributeIds = features.getIndices();
    double[] weights = features.getValues();

    // used only to name the attribute in the assertion message
    Map<Integer, String> idToForm = learner.getLexicon().getInvertedIndex();

    // every attribute comes from the "keywords" field, so each weight must
    // be the boolean value
    double expected = 1.0;
    for (int pos = 0; pos < attributeIds.length; pos++) {
        String label = idToForm.get(attributeIds[pos]);
        assertEquals("label: " + label, expected, weights[pos]);
    }
}

}

0 comments on commit e266857

Please sign in to comment.