Merge "Follow-up CL of ag/6935284, add entities to extras in generateLinks" into qt-dev

This commit is contained in:
TreeHugger Robot
2019-04-17 10:11:18 +00:00
committed by Android (Google) Code Review
3 changed files with 64 additions and 2 deletions

View File

@@ -36,6 +36,7 @@ import java.util.List;
// TODO: Make this a TestApi for CTS testing.
public final class ExtrasUtils {
// Keys for response objects.
private static final String SERIALIZED_ENTITIES_DATA = "serialized-entities-data";
private static final String ENTITIES_EXTRAS = "entities-extras";
private static final String ACTION_INTENT = "action-intent";
@@ -48,6 +49,10 @@ public final class ExtrasUtils {
private static final String TEXT_LANGUAGES = "text-languages";
private static final String ENTITIES = "entities";
// Keys for request objects.
private static final String IS_SERIALIZED_ENTITY_DATA_ENABLED =
"is-serialized-entity-data-enabled";
private ExtrasUtils() {}
/**
@@ -308,7 +313,23 @@ public final class ExtrasUtils {
/**
* Returns a list of entities contained in the {@code extra}.
*/
@Nullable
public static List<Bundle> getEntities(Bundle container) {
return container.getParcelableArrayList(ENTITIES);
}
/**
* Whether the annotator should populate serialized entity data into the result object.
*/
public static boolean isSerializedEntityDataEnabled(TextLinks.Request request) {
return request.getExtras().getBoolean(IS_SERIALIZED_ENTITY_DATA_ENABLED);
}
/**
* To indicate whether the annotator should populate serialized entity data in the result
* object.
*/
public static void putIsSerializedEntityDataEnabled(Bundle bundle, boolean isEnabled) {
bundle.putBoolean(IS_SERIALIZED_ENTITY_DATA_ENABLED, isEnabled);
}
}

View File

@@ -307,6 +307,8 @@ public final class TextClassifierImpl implements TextClassifier {
final String detectLanguageTags = detectLanguageTagsFromText(request.getText());
final AnnotatorModel annotatorImpl =
getAnnotatorImpl(request.getDefaultLocales());
final boolean isSerializedEntityDataEnabled =
ExtrasUtils.isSerializedEntityDataEnabled(request);
final AnnotatorModel.AnnotatedSpan[] annotations =
annotatorImpl.annotate(
textString,
@@ -314,7 +316,10 @@ public final class TextClassifierImpl implements TextClassifier {
refTime.toInstant().toEpochMilli(),
refTime.getZone().getId(),
localesString,
detectLanguageTags));
detectLanguageTags,
entitiesToIdentify,
AnnotatorModel.AnnotationUsecase.SMART.getValue(),
isSerializedEntityDataEnabled));
for (AnnotatorModel.AnnotatedSpan span : annotations) {
final AnnotatorModel.ClassificationResult[] results =
span.getClassification();
@@ -326,7 +331,11 @@ public final class TextClassifierImpl implements TextClassifier {
for (int i = 0; i < results.length; i++) {
entityScores.put(results[i].getCollection(), results[i].getScore());
}
builder.addLink(span.getStartIndex(), span.getEndIndex(), entityScores);
Bundle extras = new Bundle();
if (isSerializedEntityDataEnabled) {
ExtrasUtils.putEntities(extras, results);
}
builder.addLink(span.getStartIndex(), span.getEndIndex(), entityScores, extras);
}
final TextLinks links = builder.build();
final long endTimeMs = System.currentTimeMillis();

View File

@@ -361,6 +361,38 @@ public class TextClassifierTest {
mClassifier.generateLinks(request);
}
@Test
public void testGenerateLinks_entityData() {
if (isTextClassifierDisabled()) return;
String text = "The number is +12122537077.";
Bundle extras = new Bundle();
ExtrasUtils.putIsSerializedEntityDataEnabled(extras, true);
TextLinks.Request request = new TextLinks.Request.Builder(text).setExtras(extras).build();
TextLinks textLinks = mClassifier.generateLinks(request);
Truth.assertThat(textLinks.getLinks()).hasSize(1);
TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras());
Truth.assertThat(entities).hasSize(1);
Bundle entity = entities.get(0);
Truth.assertThat(ExtrasUtils.getEntityType(entity)).isEqualTo(TextClassifier.TYPE_PHONE);
}
@Test
public void testGenerateLinks_entityData_disabled() {
if (isTextClassifierDisabled()) return;
String text = "The number is +12122537077.";
TextLinks.Request request = new TextLinks.Request.Builder(text).build();
TextLinks textLinks = mClassifier.generateLinks(request);
Truth.assertThat(textLinks.getLinks()).hasSize(1);
TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras());
Truth.assertThat(entities).isNull();
}
@Test
public void testDetectLanguage() {
if (isTextClassifierDisabled()) return;