Storage refactor for EntityConfidence

Caching the sorted entity list so users don't need to be careful to cache
the result of getEntities (previously dont by TextSelection and
TextClassification, but not TextLink). Also switched to ArrayMap as it's
better suited for small maps like the ones generated by the classifier.

Test: Ran FrameworksCoreTests
Change-Id: I08cc9f72146ccab88b6a3624f3775a366c814f7a
This commit is contained in:
Jan Althaus
2017-11-30 15:01:40 +01:00
parent 05013b3772
commit bbe43dfd97
4 changed files with 43 additions and 52 deletions

View File

@@ -18,13 +18,12 @@ package android.view.textclassifier;
import android.annotation.FloatRange;
import android.annotation.NonNull;
import android.util.ArrayMap;
import com.android.internal.util.Preconditions;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -36,42 +35,43 @@ import java.util.Map;
*/
final class EntityConfidence<T> {
private final Map<T, Float> mEntityConfidence = new HashMap<>();
private final Comparator<T> mEntityComparator = (e1, e2) -> {
float score1 = mEntityConfidence.get(e1);
float score2 = mEntityConfidence.get(e2);
if (score1 > score2) {
return -1;
}
if (score1 < score2) {
return 1;
}
return 0;
};
private final ArrayMap<T, Float> mEntityConfidence = new ArrayMap<>();
private final ArrayList<T> mSortedEntities = new ArrayList<>();
EntityConfidence() {}
EntityConfidence(@NonNull EntityConfidence<T> source) {
Preconditions.checkNotNull(source);
mEntityConfidence.putAll(source.mEntityConfidence);
mSortedEntities.addAll(source.mSortedEntities);
}
/**
* Sets an entity type for the classified text and assigns a confidence score.
* Constructs an EntityConfidence from a map of entity to confidence.
*
* @param confidenceScore a value from 0 (low confidence) to 1 (high confidence).
* 0 implies the entity does not exist for the classified text.
* Values greater than 1 are clamped to 1.
* Map entries that have 0 confidence are removed, and values greater than 1 are clamped to 1.
*
* @param source a map from entity to a confidence value in the range 0 (low confidence) to
* 1 (high confidence).
*/
public void setEntityType(
@NonNull T type, @FloatRange(from = 0.0, to = 1.0) float confidenceScore) {
Preconditions.checkNotNull(type);
if (confidenceScore > 0) {
mEntityConfidence.put(type, Math.min(1, confidenceScore));
} else {
mEntityConfidence.remove(type);
EntityConfidence(@NonNull Map<T, Float> source) {
Preconditions.checkNotNull(source);
// Prune non-existent entities and clamp to 1.
mEntityConfidence.ensureCapacity(source.size());
for (Map.Entry<T, Float> it : source.entrySet()) {
if (it.getValue() <= 0) continue;
mEntityConfidence.put(it.getKey(), Math.min(1, it.getValue()));
}
// Create a list of entities sorted by decreasing confidence for getEntities().
mSortedEntities.ensureCapacity(mEntityConfidence.size());
mSortedEntities.addAll(mEntityConfidence.keySet());
mSortedEntities.sort((e1, e2) -> {
float score1 = mEntityConfidence.get(e1);
float score2 = mEntityConfidence.get(e2);
return Float.compare(score2, score1);
});
}
/**
@@ -80,10 +80,7 @@ final class EntityConfidence<T> {
*/
@NonNull
public List<T> getEntities() {
List<T> entities = new ArrayList<>(mEntityConfidence.size());
entities.addAll(mEntityConfidence.keySet());
entities.sort(mEntityComparator);
return Collections.unmodifiableList(entities);
return Collections.unmodifiableList(mSortedEntities);
}
/**

View File

@@ -24,6 +24,7 @@ import android.content.Context;
import android.content.Intent;
import android.graphics.drawable.Drawable;
import android.os.LocaleList;
import android.util.ArrayMap;
import android.view.View.OnClickListener;
import android.view.textclassifier.TextClassifier.EntityType;
@@ -32,6 +33,7 @@ import com.android.internal.util.Preconditions;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
/**
* Information for generating a widget to handle classified text.
@@ -95,7 +97,6 @@ public final class TextClassification {
@NonNull private final List<Intent> mIntents;
@NonNull private final List<OnClickListener> mOnClickListeners;
@NonNull private final EntityConfidence<String> mEntityConfidence;
@NonNull private final List<String> mEntities;
private int mLogType;
@NonNull private final String mVersionInfo;
@@ -105,7 +106,7 @@ public final class TextClassification {
@NonNull List<String> labels,
@NonNull List<Intent> intents,
@NonNull List<OnClickListener> onClickListeners,
@NonNull EntityConfidence<String> entityConfidence,
@NonNull Map<String, Float> entityConfidence,
int logType,
@NonNull String versionInfo) {
Preconditions.checkArgument(labels.size() == intents.size());
@@ -117,7 +118,6 @@ public final class TextClassification {
mIntents = intents;
mOnClickListeners = onClickListeners;
mEntityConfidence = new EntityConfidence<>(entityConfidence);
mEntities = mEntityConfidence.getEntities();
mLogType = logType;
mVersionInfo = versionInfo;
}
@@ -135,7 +135,7 @@ public final class TextClassification {
*/
@IntRange(from = 0)
public int getEntityCount() {
return mEntities.size();
return mEntityConfidence.getEntities().size();
}
/**
@@ -147,7 +147,7 @@ public final class TextClassification {
*/
@NonNull
public @EntityType String getEntity(int index) {
return mEntities.get(index);
return mEntityConfidence.getEntities().get(index);
}
/**
@@ -311,8 +311,7 @@ public final class TextClassification {
@NonNull private final List<String> mLabels = new ArrayList<>();
@NonNull private final List<Intent> mIntents = new ArrayList<>();
@NonNull private final List<OnClickListener> mOnClickListeners = new ArrayList<>();
@NonNull private final EntityConfidence<String> mEntityConfidence =
new EntityConfidence<>();
@NonNull private final Map<String, Float> mEntityConfidence = new ArrayMap<>();
private int mLogType;
@NonNull private String mVersionInfo = "";
@@ -334,7 +333,7 @@ public final class TextClassification {
public Builder setEntityType(
@NonNull @EntityType String type,
@FloatRange(from = 0.0, to = 1.0) float confidenceScore) {
mEntityConfidence.setEntityType(type, confidenceScore);
mEntityConfidence.put(type, confidenceScore);
return this;
}

View File

@@ -103,11 +103,7 @@ public final class TextLinks {
mOriginalText = originalText;
mStart = start;
mEnd = end;
mEntityScores = new EntityConfidence<>();
for (Map.Entry<String, Float> entry : entityScores.entrySet()) {
mEntityScores.setEntityType(entry.getKey(), entry.getValue());
}
mEntityScores = new EntityConfidence<>(entityScores);
}
/**

View File

@@ -21,12 +21,13 @@ import android.annotation.IntRange;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.os.LocaleList;
import android.util.ArrayMap;
import android.view.textclassifier.TextClassifier.EntityType;
import com.android.internal.util.Preconditions;
import java.util.List;
import java.util.Locale;
import java.util.Map;
/**
* Information about where text selection should be.
@@ -36,7 +37,6 @@ public final class TextSelection {
private final int mStartIndex;
private final int mEndIndex;
@NonNull private final EntityConfidence<String> mEntityConfidence;
@NonNull private final List<String> mEntities;
@NonNull private final String mLogSource;
@NonNull private final String mVersionInfo;
@@ -46,7 +46,6 @@ public final class TextSelection {
mStartIndex = startIndex;
mEndIndex = endIndex;
mEntityConfidence = new EntityConfidence<>(entityConfidence);
mEntities = mEntityConfidence.getEntities();
mLogSource = logSource;
mVersionInfo = versionInfo;
}
@@ -70,7 +69,7 @@ public final class TextSelection {
*/
@IntRange(from = 0)
public int getEntityCount() {
return mEntities.size();
return mEntityConfidence.getEntities().size();
}
/**
@@ -82,7 +81,7 @@ public final class TextSelection {
*/
@NonNull
public @EntityType String getEntity(int index) {
return mEntities.get(index);
return mEntityConfidence.getEntities().get(index);
}
/**
@@ -126,8 +125,7 @@ public final class TextSelection {
private final int mStartIndex;
private final int mEndIndex;
@NonNull private final EntityConfidence<String> mEntityConfidence =
new EntityConfidence<>();
@NonNull private final Map<String, Float> mEntityConfidence = new ArrayMap<>();
@NonNull private String mLogSource = "";
@NonNull private String mVersionInfo = "";
@@ -154,7 +152,7 @@ public final class TextSelection {
public Builder setEntityType(
@NonNull @EntityType String type,
@FloatRange(from = 0.0, to = 1.0) float confidenceScore) {
mEntityConfidence.setEntityType(type, confidenceScore);
mEntityConfidence.put(type, confidenceScore);
return this;
}
@@ -181,7 +179,8 @@ public final class TextSelection {
*/
public TextSelection build() {
return new TextSelection(
mStartIndex, mEndIndex, mEntityConfidence, mLogSource, mVersionInfo);
mStartIndex, mEndIndex, new EntityConfidence<>(mEntityConfidence), mLogSource,
mVersionInfo);
}
}