Merge "Follow-up CL of ag/6935284, add entities to extras in generateLinks" into qt-dev
This commit is contained in:
committed by
Android (Google) Code Review
commit
59ed9a7f27
@@ -36,6 +36,7 @@ import java.util.List;
|
||||
// TODO: Make this a TestApi for CTS testing.
|
||||
public final class ExtrasUtils {
|
||||
|
||||
// Keys for response objects.
|
||||
private static final String SERIALIZED_ENTITIES_DATA = "serialized-entities-data";
|
||||
private static final String ENTITIES_EXTRAS = "entities-extras";
|
||||
private static final String ACTION_INTENT = "action-intent";
|
||||
@@ -48,6 +49,10 @@ public final class ExtrasUtils {
|
||||
private static final String TEXT_LANGUAGES = "text-languages";
|
||||
private static final String ENTITIES = "entities";
|
||||
|
||||
// Keys for request objects.
|
||||
private static final String IS_SERIALIZED_ENTITY_DATA_ENABLED =
|
||||
"is-serialized-entity-data-enabled";
|
||||
|
||||
private ExtrasUtils() {}
|
||||
|
||||
/**
|
||||
@@ -308,7 +313,23 @@ public final class ExtrasUtils {
|
||||
/**
|
||||
* Returns a list of entities contained in the {@code extra}.
|
||||
*/
|
||||
@Nullable
|
||||
public static List<Bundle> getEntities(Bundle container) {
|
||||
return container.getParcelableArrayList(ENTITIES);
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether the annotator should populate serialized entity data into the result object.
|
||||
*/
|
||||
public static boolean isSerializedEntityDataEnabled(TextLinks.Request request) {
|
||||
return request.getExtras().getBoolean(IS_SERIALIZED_ENTITY_DATA_ENABLED);
|
||||
}
|
||||
|
||||
/**
|
||||
* To indicate whether the annotator should populate serialized entity data in the result
|
||||
* object.
|
||||
*/
|
||||
public static void putIsSerializedEntityDataEnabled(Bundle bundle, boolean isEnabled) {
|
||||
bundle.putBoolean(IS_SERIALIZED_ENTITY_DATA_ENABLED, isEnabled);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -307,6 +307,8 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
final String detectLanguageTags = detectLanguageTagsFromText(request.getText());
|
||||
final AnnotatorModel annotatorImpl =
|
||||
getAnnotatorImpl(request.getDefaultLocales());
|
||||
final boolean isSerializedEntityDataEnabled =
|
||||
ExtrasUtils.isSerializedEntityDataEnabled(request);
|
||||
final AnnotatorModel.AnnotatedSpan[] annotations =
|
||||
annotatorImpl.annotate(
|
||||
textString,
|
||||
@@ -314,7 +316,10 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
refTime.toInstant().toEpochMilli(),
|
||||
refTime.getZone().getId(),
|
||||
localesString,
|
||||
detectLanguageTags));
|
||||
detectLanguageTags,
|
||||
entitiesToIdentify,
|
||||
AnnotatorModel.AnnotationUsecase.SMART.getValue(),
|
||||
isSerializedEntityDataEnabled));
|
||||
for (AnnotatorModel.AnnotatedSpan span : annotations) {
|
||||
final AnnotatorModel.ClassificationResult[] results =
|
||||
span.getClassification();
|
||||
@@ -326,7 +331,11 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
for (int i = 0; i < results.length; i++) {
|
||||
entityScores.put(results[i].getCollection(), results[i].getScore());
|
||||
}
|
||||
builder.addLink(span.getStartIndex(), span.getEndIndex(), entityScores);
|
||||
Bundle extras = new Bundle();
|
||||
if (isSerializedEntityDataEnabled) {
|
||||
ExtrasUtils.putEntities(extras, results);
|
||||
}
|
||||
builder.addLink(span.getStartIndex(), span.getEndIndex(), entityScores, extras);
|
||||
}
|
||||
final TextLinks links = builder.build();
|
||||
final long endTimeMs = System.currentTimeMillis();
|
||||
|
||||
@@ -361,6 +361,38 @@ public class TextClassifierTest {
|
||||
mClassifier.generateLinks(request);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGenerateLinks_entityData() {
|
||||
if (isTextClassifierDisabled()) return;
|
||||
String text = "The number is +12122537077.";
|
||||
Bundle extras = new Bundle();
|
||||
ExtrasUtils.putIsSerializedEntityDataEnabled(extras, true);
|
||||
TextLinks.Request request = new TextLinks.Request.Builder(text).setExtras(extras).build();
|
||||
|
||||
TextLinks textLinks = mClassifier.generateLinks(request);
|
||||
|
||||
Truth.assertThat(textLinks.getLinks()).hasSize(1);
|
||||
TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
|
||||
List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras());
|
||||
Truth.assertThat(entities).hasSize(1);
|
||||
Bundle entity = entities.get(0);
|
||||
Truth.assertThat(ExtrasUtils.getEntityType(entity)).isEqualTo(TextClassifier.TYPE_PHONE);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGenerateLinks_entityData_disabled() {
|
||||
if (isTextClassifierDisabled()) return;
|
||||
String text = "The number is +12122537077.";
|
||||
TextLinks.Request request = new TextLinks.Request.Builder(text).build();
|
||||
|
||||
TextLinks textLinks = mClassifier.generateLinks(request);
|
||||
|
||||
Truth.assertThat(textLinks.getLinks()).hasSize(1);
|
||||
TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
|
||||
List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras());
|
||||
Truth.assertThat(entities).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDetectLanguage() {
|
||||
if (isTextClassifierDisabled()) return;
|
||||
|
||||
Reference in New Issue
Block a user