Implement TextClassifierImpl.detectLanguage()

- Includes some fixes to handle null ParcelFileDescriptors.
- Closes fds immediately after the model has been loaded.

Bug: 116020587
Test: atest android.view.textclassifier.TextClassificationManagerTest
Change-Id: Ieb05d081847ac218d2a5b46db95cd512838f67ab
This commit is contained in:
Abodunrinwa Toki
2018-10-19 20:58:26 +01:00
parent c2896a27fa
commit ee3a48eec0
2 changed files with 168 additions and 41 deletions

View File

@@ -31,6 +31,7 @@ import android.content.Intent;
import android.content.pm.PackageManager;
import android.content.pm.ResolveInfo;
import android.graphics.drawable.Icon;
import android.icu.util.ULocale;
import android.net.Uri;
import android.os.Bundle;
import android.os.LocaleList;
@@ -45,6 +46,7 @@ import com.android.internal.util.IndentingPrintWriter;
import com.android.internal.util.Preconditions;
import com.google.android.textclassifier.AnnotatorModel;
import com.google.android.textclassifier.LangIdModel;
import java.io.File;
import java.io.FileNotFoundException;
@@ -83,6 +85,9 @@ public final class TextClassifierImpl implements TextClassifier {
private static final String MODEL_FILE_REGEX = "textclassifier\\.(.*)\\.model";
private static final String UPDATED_MODEL_FILE_PATH =
"/data/misc/textclassifier/textclassifier.model";
private static final String LANG_ID_MODEL_FILE_PATH = "/etc/textclassifier/lang_id.model";
private static final String UPDATED_LANG_ID_MODEL_FILE_PATH =
"/data/misc/textclassifier/lang_id.model";
private final Context mContext;
private final TextClassifier mFallback;
@@ -94,7 +99,9 @@ public final class TextClassifierImpl implements TextClassifier {
@GuardedBy("mLock") // Do not access outside this lock.
private ModelFile mModel;
@GuardedBy("mLock") // Do not access outside this lock.
private AnnotatorModel mNative;
private AnnotatorModel mAnnotatorImpl;
@GuardedBy("mLock") // Do not access outside this lock.
private LangIdModel mLangIdImpl;
private final Object mLoggerLock = new Object();
@GuardedBy("mLoggerLock") // Do not access outside this lock.
@@ -127,14 +134,15 @@ public final class TextClassifierImpl implements TextClassifier {
&& rangeLength <= mSettings.getSuggestSelectionMaxRangeLength()) {
final String localesString = concatenateLocales(request.getDefaultLocales());
final ZonedDateTime refTime = ZonedDateTime.now();
final AnnotatorModel nativeImpl = getNative(request.getDefaultLocales());
final AnnotatorModel annotatorImpl =
getAnnotatorImpl(request.getDefaultLocales());
final int start;
final int end;
if (mSettings.isModelDarkLaunchEnabled() && !request.isDarkLaunchAllowed()) {
start = request.getStartIndex();
end = request.getEndIndex();
} else {
final int[] startEnd = nativeImpl.suggestSelection(
final int[] startEnd = annotatorImpl.suggestSelection(
string, request.getStartIndex(), request.getEndIndex(),
new AnnotatorModel.SelectionOptions(localesString));
start = startEnd[0];
@@ -145,7 +153,7 @@ public final class TextClassifierImpl implements TextClassifier {
&& start <= request.getStartIndex() && end >= request.getEndIndex()) {
final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);
final AnnotatorModel.ClassificationResult[] results =
nativeImpl.classifyText(
annotatorImpl.classifyText(
string, start, end,
new AnnotatorModel.ClassificationOptions(
refTime.toInstant().toEpochMilli(),
@@ -187,7 +195,7 @@ public final class TextClassifierImpl implements TextClassifier {
final ZonedDateTime refTime = request.getReferenceTime() != null
? request.getReferenceTime() : ZonedDateTime.now();
final AnnotatorModel.ClassificationResult[] results =
getNative(request.getDefaultLocales())
getAnnotatorImpl(request.getDefaultLocales())
.classifyText(
string, request.getStartIndex(), request.getEndIndex(),
new AnnotatorModel.ClassificationOptions(
@@ -230,10 +238,10 @@ public final class TextClassifierImpl implements TextClassifier {
? request.getEntityConfig().resolveEntityListModifications(
getEntitiesForHints(request.getEntityConfig().getHints()))
: mSettings.getEntityListDefault();
final AnnotatorModel nativeImpl =
getNative(request.getDefaultLocales());
final AnnotatorModel annotatorImpl =
getAnnotatorImpl(request.getDefaultLocales());
final AnnotatorModel.AnnotatedSpan[] annotations =
nativeImpl.annotate(
annotatorImpl.annotate(
textString,
new AnnotatorModel.AnnotationOptions(
refTime.toInstant().toEpochMilli(),
@@ -288,6 +296,7 @@ public final class TextClassifierImpl implements TextClassifier {
}
}
/** @inheritDoc */
@Override
public void onSelectionEvent(SelectionEvent event) {
Preconditions.checkNotNull(event);
@@ -299,7 +308,29 @@ public final class TextClassifierImpl implements TextClassifier {
}
}
private AnnotatorModel getNative(LocaleList localeList)
/** @inheritDoc */
@Override
public TextLanguage detectLanguage(@NonNull TextLanguage.Request request) {
Preconditions.checkNotNull(request);
Utils.checkMainThread();
try {
final TextLanguage.Builder builder = new TextLanguage.Builder();
final LangIdModel.LanguageResult[] langResults =
getLangIdImpl().detectLanguages(request.getText().toString());
for (int i = 0; i < langResults.length; i++) {
builder.putLocale(
ULocale.forLanguageTag(langResults[i].getLanguage()),
langResults[i].getScore());
}
return builder.build();
} catch (Throwable t) {
// Avoid throwing from this method. Log the error.
Log.e(LOG_TAG, "Error detecting text language.", t);
}
return mFallback.detectLanguage(request);
}
private AnnotatorModel getAnnotatorImpl(LocaleList localeList)
throws FileNotFoundException {
synchronized (mLock) {
localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList;
@@ -307,16 +338,72 @@ public final class TextClassifierImpl implements TextClassifier {
if (bestModel == null) {
throw new FileNotFoundException("No model for " + localeList.toLanguageTags());
}
if (mNative == null || !Objects.equals(mModel, bestModel)) {
if (mAnnotatorImpl == null || !Objects.equals(mModel, bestModel)) {
Log.d(DEFAULT_LOG_TAG, "Loading " + bestModel);
destroyNativeIfExistsLocked();
destroyAnnotatorImplIfExistsLocked();
final ParcelFileDescriptor fd = ParcelFileDescriptor.open(
new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
mNative = new AnnotatorModel(fd.getFd());
closeAndLogError(fd);
mModel = bestModel;
try {
if (fd != null) {
mAnnotatorImpl = new AnnotatorModel(fd.getFd());
mModel = bestModel;
}
} finally {
maybeCloseAndLogError(fd);
}
}
return mNative;
return mAnnotatorImpl;
}
}
@GuardedBy("mLock") // Do not call outside this lock.
private void destroyAnnotatorImplIfExistsLocked() {
if (mAnnotatorImpl != null) {
mAnnotatorImpl.close();
mAnnotatorImpl = null;
}
}
private LangIdModel getLangIdImpl() throws FileNotFoundException {
synchronized (mLock) {
if (mLangIdImpl == null) {
ParcelFileDescriptor factoryFd = null;
ParcelFileDescriptor updateFd = null;
try {
int factoryVersion = -1;
int updateVersion = factoryVersion;
final File factoryFile = new File(LANG_ID_MODEL_FILE_PATH);
if (factoryFile.exists()) {
factoryFd = ParcelFileDescriptor.open(
factoryFile, ParcelFileDescriptor.MODE_READ_ONLY);
// TODO: Uncomment when method is implemented:
// if (factoryFd != null) {
// factoryVersion = LangIdModel.getVersion(factoryFd.getFd());
// }
}
final File updateFile = new File(UPDATED_LANG_ID_MODEL_FILE_PATH);
if (updateFile.exists()) {
updateFd = ParcelFileDescriptor.open(
updateFile, ParcelFileDescriptor.MODE_READ_ONLY);
// TODO: Uncomment when method is implemented:
// if (updateFd != null) {
// updateVersion = LangIdModel.getVersion(updateFd.getFd());
// }
}
if (updateVersion > factoryVersion) {
mLangIdImpl = new LangIdModel(updateFd.getFd());
} else if (factoryFd != null) {
mLangIdImpl = new LangIdModel(factoryFd.getFd());
} else {
throw new FileNotFoundException("Language detection model not found");
}
} finally {
maybeCloseAndLogError(factoryFd);
maybeCloseAndLogError(updateFd);
}
}
return mLangIdImpl;
}
}
@@ -327,14 +414,6 @@ public final class TextClassifierImpl implements TextClassifier {
}
}
@GuardedBy("mLock") // Do not call outside this lock.
private void destroyNativeIfExistsLocked() {
if (mNative != null) {
mNative.close();
mNative = null;
}
}
private static String concatenateLocales(@Nullable LocaleList locales) {
return (locales == null) ? "" : locales.toLanguageTags();
}
@@ -407,20 +486,19 @@ public final class TextClassifierImpl implements TextClassifier {
.setText(classifiedText);
final int size = classifications.length;
AnnotatorModel.ClassificationResult highestScoringResult = null;
float highestScore = Float.MIN_VALUE;
AnnotatorModel.ClassificationResult highestScoringResult =
size > 0 ? classifications[0] : null;
for (int i = 0; i < size; i++) {
builder.setEntityType(classifications[i].getCollection(),
classifications[i].getScore());
if (classifications[i].getScore() > highestScore) {
if (classifications[i].getScore() > highestScoringResult.getScore()) {
highestScoringResult = classifications[i];
highestScore = classifications[i].getScore();
}
}
boolean isPrimaryAction = true;
for (LabeledIntent labeledIntent : IntentFactory.create(
mContext, referenceTime, highestScoringResult, classifiedText)) {
mContext, classifiedText, referenceTime, highestScoringResult)) {
final RemoteAction action = labeledIntent.asRemoteAction(mContext);
if (action == null) {
continue;
@@ -461,9 +539,13 @@ public final class TextClassifierImpl implements TextClassifier {
}
/**
* Closes the ParcelFileDescriptor and logs any errors that occur.
* Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur.
*/
private static void closeAndLogError(ParcelFileDescriptor fd) {
private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
if (fd == null) {
return;
}
try {
fd.close();
} catch (IOException e) {
@@ -485,12 +567,17 @@ public final class TextClassifierImpl implements TextClassifier {
/** Returns null if the path did not point to a compatible model. */
static @Nullable ModelFile fromPath(String path) {
final File file = new File(path);
if (!file.exists()) {
return null;
}
ParcelFileDescriptor modelFd = null;
try {
final ParcelFileDescriptor modelFd = ParcelFileDescriptor.open(
file, ParcelFileDescriptor.MODE_READ_ONLY);
modelFd = ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
if (modelFd == null) {
return null;
}
final int version = AnnotatorModel.getVersion(modelFd.getFd());
final String supportedLocalesStr =
AnnotatorModel.getLocales(modelFd.getFd());
final String supportedLocalesStr = AnnotatorModel.getLocales(modelFd.getFd());
if (supportedLocalesStr.isEmpty()) {
Log.d(DEFAULT_LOG_TAG, "Ignoring " + file.getAbsolutePath());
return null;
@@ -500,12 +587,13 @@ public final class TextClassifierImpl implements TextClassifier {
for (String langTag : supportedLocalesStr.split(",")) {
supportedLocales.add(Locale.forLanguageTag(langTag));
}
closeAndLogError(modelFd);
return new ModelFile(path, file.getName(), version, supportedLocales,
languageIndependent);
} catch (FileNotFoundException e) {
Log.e(DEFAULT_LOG_TAG, "Failed to peek " + file.getAbsolutePath(), e);
return null;
} finally {
maybeCloseAndLogError(modelFd);
}
}
@@ -557,12 +645,12 @@ public final class TextClassifierImpl implements TextClassifier {
public boolean equals(Object other) {
if (this == other) {
return true;
} else if (other == null || !ModelFile.class.isAssignableFrom(other.getClass())) {
return false;
} else {
}
if (other instanceof ModelFile) {
final ModelFile otherModel = (ModelFile) other;
return mPath.equals(otherModel.mPath);
}
return false;
}
@Override
@@ -677,10 +765,12 @@ public final class TextClassifierImpl implements TextClassifier {
@NonNull
public static List<LabeledIntent> create(
Context context,
String text,
@Nullable Instant referenceTime,
AnnotatorModel.ClassificationResult classification,
String text) {
final String type = classification.getCollection().trim().toLowerCase(Locale.ENGLISH);
@Nullable AnnotatorModel.ClassificationResult classification) {
final String type = classification != null
? classification.getCollection().trim().toLowerCase(Locale.ENGLISH)
: null;
text = text.trim();
switch (type) {
case TextClassifier.TYPE_EMAIL: