diff --git a/core/java/android/view/textclassifier/ExtrasUtils.java b/core/java/android/view/textclassifier/ExtrasUtils.java index 7b236747bae67..11e0e2ca072c0 100644 --- a/core/java/android/view/textclassifier/ExtrasUtils.java +++ b/core/java/android/view/textclassifier/ExtrasUtils.java @@ -36,6 +36,7 @@ import java.util.List; // TODO: Make this a TestApi for CTS testing. public final class ExtrasUtils { + // Keys for response objects. private static final String SERIALIZED_ENTITIES_DATA = "serialized-entities-data"; private static final String ENTITIES_EXTRAS = "entities-extras"; private static final String ACTION_INTENT = "action-intent"; @@ -48,6 +49,10 @@ public final class ExtrasUtils { private static final String TEXT_LANGUAGES = "text-languages"; private static final String ENTITIES = "entities"; + // Keys for request objects. + private static final String IS_SERIALIZED_ENTITY_DATA_ENABLED = + "is-serialized-entity-data-enabled"; + private ExtrasUtils() {} /** @@ -308,7 +313,23 @@ public final class ExtrasUtils { /** * Returns a list of entities contained in the {@code extra}. */ + @Nullable public static List getEntities(Bundle container) { return container.getParcelableArrayList(ENTITIES); } + + /** + * Whether the annotator should populate serialized entity data into the result object. + */ + public static boolean isSerializedEntityDataEnabled(TextLinks.Request request) { + return request.getExtras().getBoolean(IS_SERIALIZED_ENTITY_DATA_ENABLED); + } + + /** + * To indicate whether the annotator should populate serialized entity data in the result + * object. + */ + public static void putIsSerializedEntityDataEnabled(Bundle bundle, boolean isEnabled) { + bundle.putBoolean(IS_SERIALIZED_ENTITY_DATA_ENABLED, isEnabled); + } } diff --git a/core/java/android/view/textclassifier/TextClassifierImpl.java b/core/java/android/view/textclassifier/TextClassifierImpl.java index ed6ec54986ad5..3297523b0da9f 100644 --- a/core/java/android/view/textclassifier/TextClassifierImpl.java +++ b/core/java/android/view/textclassifier/TextClassifierImpl.java @@ -307,6 +307,8 @@ public final class TextClassifierImpl implements TextClassifier { final String detectLanguageTags = detectLanguageTagsFromText(request.getText()); final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales()); + final boolean isSerializedEntityDataEnabled = + ExtrasUtils.isSerializedEntityDataEnabled(request); final AnnotatorModel.AnnotatedSpan[] annotations = annotatorImpl.annotate( textString, @@ -314,7 +316,10 @@ public final class TextClassifierImpl implements TextClassifier { refTime.toInstant().toEpochMilli(), refTime.getZone().getId(), localesString, - detectLanguageTags)); + detectLanguageTags, + entitiesToIdentify, + AnnotatorModel.AnnotationUsecase.SMART.getValue(), + isSerializedEntityDataEnabled)); for (AnnotatorModel.AnnotatedSpan span : annotations) { final AnnotatorModel.ClassificationResult[] results = span.getClassification(); @@ -326,7 +331,11 @@ public final class TextClassifierImpl implements TextClassifier { for (int i = 0; i < results.length; i++) { entityScores.put(results[i].getCollection(), results[i].getScore()); } - builder.addLink(span.getStartIndex(), span.getEndIndex(), entityScores); + Bundle extras = new Bundle(); + if (isSerializedEntityDataEnabled) { + ExtrasUtils.putEntities(extras, results); + } + builder.addLink(span.getStartIndex(), span.getEndIndex(), entityScores, extras); } final TextLinks links = builder.build(); final long endTimeMs = System.currentTimeMillis(); diff --git a/core/tests/coretests/src/android/view/textclassifier/TextClassifierTest.java b/core/tests/coretests/src/android/view/textclassifier/TextClassifierTest.java index 79512a744f52a..aeb8949c69767 100644 --- a/core/tests/coretests/src/android/view/textclassifier/TextClassifierTest.java +++ b/core/tests/coretests/src/android/view/textclassifier/TextClassifierTest.java @@ -361,6 +361,38 @@ public class TextClassifierTest { mClassifier.generateLinks(request); } + @Test + public void testGenerateLinks_entityData() { + if (isTextClassifierDisabled()) return; + String text = "The number is +12122537077."; + Bundle extras = new Bundle(); + ExtrasUtils.putIsSerializedEntityDataEnabled(extras, true); + TextLinks.Request request = new TextLinks.Request.Builder(text).setExtras(extras).build(); + + TextLinks textLinks = mClassifier.generateLinks(request); + + Truth.assertThat(textLinks.getLinks()).hasSize(1); + TextLinks.TextLink textLink = textLinks.getLinks().iterator().next(); + List entities = ExtrasUtils.getEntities(textLink.getExtras()); + Truth.assertThat(entities).hasSize(1); + Bundle entity = entities.get(0); + Truth.assertThat(ExtrasUtils.getEntityType(entity)).isEqualTo(TextClassifier.TYPE_PHONE); + } + + @Test + public void testGenerateLinks_entityData_disabled() { + if (isTextClassifierDisabled()) return; + String text = "The number is +12122537077."; + TextLinks.Request request = new TextLinks.Request.Builder(text).build(); + + TextLinks textLinks = mClassifier.generateLinks(request); + + Truth.assertThat(textLinks.getLinks()).hasSize(1); + TextLinks.TextLink textLink = textLinks.getLinks().iterator().next(); + List entities = ExtrasUtils.getEntities(textLink.getExtras()); + Truth.assertThat(entities).isNull(); + } + @Test public void testDetectLanguage() { if (isTextClassifierDisabled()) return;