Implement TextClassifierImpl.detectLanguage()
- Includes some fixes to handle null ParcelFileDescriptors. - Closes fds immediately after the model has been loaded. Bug: 116020587 Test: atest android.view.textclassifier.TextClassificationManagerTest Change-Id: Ieb05d081847ac218d2a5b46db95cd512838f67ab
This commit is contained in:
@@ -31,6 +31,7 @@ import android.content.Intent;
|
||||
import android.content.pm.PackageManager;
|
||||
import android.content.pm.ResolveInfo;
|
||||
import android.graphics.drawable.Icon;
|
||||
import android.icu.util.ULocale;
|
||||
import android.net.Uri;
|
||||
import android.os.Bundle;
|
||||
import android.os.LocaleList;
|
||||
@@ -45,6 +46,7 @@ import com.android.internal.util.IndentingPrintWriter;
|
||||
import com.android.internal.util.Preconditions;
|
||||
|
||||
import com.google.android.textclassifier.AnnotatorModel;
|
||||
import com.google.android.textclassifier.LangIdModel;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
@@ -83,6 +85,9 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
private static final String MODEL_FILE_REGEX = "textclassifier\\.(.*)\\.model";
|
||||
private static final String UPDATED_MODEL_FILE_PATH =
|
||||
"/data/misc/textclassifier/textclassifier.model";
|
||||
private static final String LANG_ID_MODEL_FILE_PATH = "/etc/textclassifier/lang_id.model";
|
||||
private static final String UPDATED_LANG_ID_MODEL_FILE_PATH =
|
||||
"/data/misc/textclassifier/lang_id.model";
|
||||
|
||||
private final Context mContext;
|
||||
private final TextClassifier mFallback;
|
||||
@@ -94,7 +99,9 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
@GuardedBy("mLock") // Do not access outside this lock.
|
||||
private ModelFile mModel;
|
||||
@GuardedBy("mLock") // Do not access outside this lock.
|
||||
private AnnotatorModel mNative;
|
||||
private AnnotatorModel mAnnotatorImpl;
|
||||
@GuardedBy("mLock") // Do not access outside this lock.
|
||||
private LangIdModel mLangIdImpl;
|
||||
|
||||
private final Object mLoggerLock = new Object();
|
||||
@GuardedBy("mLoggerLock") // Do not access outside this lock.
|
||||
@@ -127,14 +134,15 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
&& rangeLength <= mSettings.getSuggestSelectionMaxRangeLength()) {
|
||||
final String localesString = concatenateLocales(request.getDefaultLocales());
|
||||
final ZonedDateTime refTime = ZonedDateTime.now();
|
||||
final AnnotatorModel nativeImpl = getNative(request.getDefaultLocales());
|
||||
final AnnotatorModel annotatorImpl =
|
||||
getAnnotatorImpl(request.getDefaultLocales());
|
||||
final int start;
|
||||
final int end;
|
||||
if (mSettings.isModelDarkLaunchEnabled() && !request.isDarkLaunchAllowed()) {
|
||||
start = request.getStartIndex();
|
||||
end = request.getEndIndex();
|
||||
} else {
|
||||
final int[] startEnd = nativeImpl.suggestSelection(
|
||||
final int[] startEnd = annotatorImpl.suggestSelection(
|
||||
string, request.getStartIndex(), request.getEndIndex(),
|
||||
new AnnotatorModel.SelectionOptions(localesString));
|
||||
start = startEnd[0];
|
||||
@@ -145,7 +153,7 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
&& start <= request.getStartIndex() && end >= request.getEndIndex()) {
|
||||
final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);
|
||||
final AnnotatorModel.ClassificationResult[] results =
|
||||
nativeImpl.classifyText(
|
||||
annotatorImpl.classifyText(
|
||||
string, start, end,
|
||||
new AnnotatorModel.ClassificationOptions(
|
||||
refTime.toInstant().toEpochMilli(),
|
||||
@@ -187,7 +195,7 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
final ZonedDateTime refTime = request.getReferenceTime() != null
|
||||
? request.getReferenceTime() : ZonedDateTime.now();
|
||||
final AnnotatorModel.ClassificationResult[] results =
|
||||
getNative(request.getDefaultLocales())
|
||||
getAnnotatorImpl(request.getDefaultLocales())
|
||||
.classifyText(
|
||||
string, request.getStartIndex(), request.getEndIndex(),
|
||||
new AnnotatorModel.ClassificationOptions(
|
||||
@@ -230,10 +238,10 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
? request.getEntityConfig().resolveEntityListModifications(
|
||||
getEntitiesForHints(request.getEntityConfig().getHints()))
|
||||
: mSettings.getEntityListDefault();
|
||||
final AnnotatorModel nativeImpl =
|
||||
getNative(request.getDefaultLocales());
|
||||
final AnnotatorModel annotatorImpl =
|
||||
getAnnotatorImpl(request.getDefaultLocales());
|
||||
final AnnotatorModel.AnnotatedSpan[] annotations =
|
||||
nativeImpl.annotate(
|
||||
annotatorImpl.annotate(
|
||||
textString,
|
||||
new AnnotatorModel.AnnotationOptions(
|
||||
refTime.toInstant().toEpochMilli(),
|
||||
@@ -288,6 +296,7 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
}
|
||||
}
|
||||
|
||||
/** @inheritDoc */
|
||||
@Override
|
||||
public void onSelectionEvent(SelectionEvent event) {
|
||||
Preconditions.checkNotNull(event);
|
||||
@@ -299,7 +308,29 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
}
|
||||
}
|
||||
|
||||
private AnnotatorModel getNative(LocaleList localeList)
|
||||
/** @inheritDoc */
|
||||
@Override
|
||||
public TextLanguage detectLanguage(@NonNull TextLanguage.Request request) {
|
||||
Preconditions.checkNotNull(request);
|
||||
Utils.checkMainThread();
|
||||
try {
|
||||
final TextLanguage.Builder builder = new TextLanguage.Builder();
|
||||
final LangIdModel.LanguageResult[] langResults =
|
||||
getLangIdImpl().detectLanguages(request.getText().toString());
|
||||
for (int i = 0; i < langResults.length; i++) {
|
||||
builder.putLocale(
|
||||
ULocale.forLanguageTag(langResults[i].getLanguage()),
|
||||
langResults[i].getScore());
|
||||
}
|
||||
return builder.build();
|
||||
} catch (Throwable t) {
|
||||
// Avoid throwing from this method. Log the error.
|
||||
Log.e(LOG_TAG, "Error detecting text language.", t);
|
||||
}
|
||||
return mFallback.detectLanguage(request);
|
||||
}
|
||||
|
||||
private AnnotatorModel getAnnotatorImpl(LocaleList localeList)
|
||||
throws FileNotFoundException {
|
||||
synchronized (mLock) {
|
||||
localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList;
|
||||
@@ -307,16 +338,72 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
if (bestModel == null) {
|
||||
throw new FileNotFoundException("No model for " + localeList.toLanguageTags());
|
||||
}
|
||||
if (mNative == null || !Objects.equals(mModel, bestModel)) {
|
||||
if (mAnnotatorImpl == null || !Objects.equals(mModel, bestModel)) {
|
||||
Log.d(DEFAULT_LOG_TAG, "Loading " + bestModel);
|
||||
destroyNativeIfExistsLocked();
|
||||
destroyAnnotatorImplIfExistsLocked();
|
||||
final ParcelFileDescriptor fd = ParcelFileDescriptor.open(
|
||||
new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
|
||||
mNative = new AnnotatorModel(fd.getFd());
|
||||
closeAndLogError(fd);
|
||||
mModel = bestModel;
|
||||
try {
|
||||
if (fd != null) {
|
||||
mAnnotatorImpl = new AnnotatorModel(fd.getFd());
|
||||
mModel = bestModel;
|
||||
}
|
||||
} finally {
|
||||
maybeCloseAndLogError(fd);
|
||||
}
|
||||
}
|
||||
return mNative;
|
||||
return mAnnotatorImpl;
|
||||
}
|
||||
}
|
||||
|
||||
@GuardedBy("mLock") // Do not call outside this lock.
|
||||
private void destroyAnnotatorImplIfExistsLocked() {
|
||||
if (mAnnotatorImpl != null) {
|
||||
mAnnotatorImpl.close();
|
||||
mAnnotatorImpl = null;
|
||||
}
|
||||
}
|
||||
|
||||
private LangIdModel getLangIdImpl() throws FileNotFoundException {
|
||||
synchronized (mLock) {
|
||||
if (mLangIdImpl == null) {
|
||||
ParcelFileDescriptor factoryFd = null;
|
||||
ParcelFileDescriptor updateFd = null;
|
||||
try {
|
||||
int factoryVersion = -1;
|
||||
int updateVersion = factoryVersion;
|
||||
final File factoryFile = new File(LANG_ID_MODEL_FILE_PATH);
|
||||
if (factoryFile.exists()) {
|
||||
factoryFd = ParcelFileDescriptor.open(
|
||||
factoryFile, ParcelFileDescriptor.MODE_READ_ONLY);
|
||||
// TODO: Uncomment when method is implemented:
|
||||
// if (factoryFd != null) {
|
||||
// factoryVersion = LangIdModel.getVersion(factoryFd.getFd());
|
||||
// }
|
||||
}
|
||||
final File updateFile = new File(UPDATED_LANG_ID_MODEL_FILE_PATH);
|
||||
if (updateFile.exists()) {
|
||||
updateFd = ParcelFileDescriptor.open(
|
||||
updateFile, ParcelFileDescriptor.MODE_READ_ONLY);
|
||||
// TODO: Uncomment when method is implemented:
|
||||
// if (updateFd != null) {
|
||||
// updateVersion = LangIdModel.getVersion(updateFd.getFd());
|
||||
// }
|
||||
}
|
||||
|
||||
if (updateVersion > factoryVersion) {
|
||||
mLangIdImpl = new LangIdModel(updateFd.getFd());
|
||||
} else if (factoryFd != null) {
|
||||
mLangIdImpl = new LangIdModel(factoryFd.getFd());
|
||||
} else {
|
||||
throw new FileNotFoundException("Language detection model not found");
|
||||
}
|
||||
} finally {
|
||||
maybeCloseAndLogError(factoryFd);
|
||||
maybeCloseAndLogError(updateFd);
|
||||
}
|
||||
}
|
||||
return mLangIdImpl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -327,14 +414,6 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
}
|
||||
}
|
||||
|
||||
@GuardedBy("mLock") // Do not call outside this lock.
|
||||
private void destroyNativeIfExistsLocked() {
|
||||
if (mNative != null) {
|
||||
mNative.close();
|
||||
mNative = null;
|
||||
}
|
||||
}
|
||||
|
||||
private static String concatenateLocales(@Nullable LocaleList locales) {
|
||||
return (locales == null) ? "" : locales.toLanguageTags();
|
||||
}
|
||||
@@ -407,20 +486,19 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
.setText(classifiedText);
|
||||
|
||||
final int size = classifications.length;
|
||||
AnnotatorModel.ClassificationResult highestScoringResult = null;
|
||||
float highestScore = Float.MIN_VALUE;
|
||||
AnnotatorModel.ClassificationResult highestScoringResult =
|
||||
size > 0 ? classifications[0] : null;
|
||||
for (int i = 0; i < size; i++) {
|
||||
builder.setEntityType(classifications[i].getCollection(),
|
||||
classifications[i].getScore());
|
||||
if (classifications[i].getScore() > highestScore) {
|
||||
if (classifications[i].getScore() > highestScoringResult.getScore()) {
|
||||
highestScoringResult = classifications[i];
|
||||
highestScore = classifications[i].getScore();
|
||||
}
|
||||
}
|
||||
|
||||
boolean isPrimaryAction = true;
|
||||
for (LabeledIntent labeledIntent : IntentFactory.create(
|
||||
mContext, referenceTime, highestScoringResult, classifiedText)) {
|
||||
mContext, classifiedText, referenceTime, highestScoringResult)) {
|
||||
final RemoteAction action = labeledIntent.asRemoteAction(mContext);
|
||||
if (action == null) {
|
||||
continue;
|
||||
@@ -461,9 +539,13 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
}
|
||||
|
||||
/**
|
||||
* Closes the ParcelFileDescriptor and logs any errors that occur.
|
||||
* Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur.
|
||||
*/
|
||||
private static void closeAndLogError(ParcelFileDescriptor fd) {
|
||||
private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
|
||||
if (fd == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
fd.close();
|
||||
} catch (IOException e) {
|
||||
@@ -485,12 +567,17 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
/** Returns null if the path did not point to a compatible model. */
|
||||
static @Nullable ModelFile fromPath(String path) {
|
||||
final File file = new File(path);
|
||||
if (!file.exists()) {
|
||||
return null;
|
||||
}
|
||||
ParcelFileDescriptor modelFd = null;
|
||||
try {
|
||||
final ParcelFileDescriptor modelFd = ParcelFileDescriptor.open(
|
||||
file, ParcelFileDescriptor.MODE_READ_ONLY);
|
||||
modelFd = ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
|
||||
if (modelFd == null) {
|
||||
return null;
|
||||
}
|
||||
final int version = AnnotatorModel.getVersion(modelFd.getFd());
|
||||
final String supportedLocalesStr =
|
||||
AnnotatorModel.getLocales(modelFd.getFd());
|
||||
final String supportedLocalesStr = AnnotatorModel.getLocales(modelFd.getFd());
|
||||
if (supportedLocalesStr.isEmpty()) {
|
||||
Log.d(DEFAULT_LOG_TAG, "Ignoring " + file.getAbsolutePath());
|
||||
return null;
|
||||
@@ -500,12 +587,13 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
for (String langTag : supportedLocalesStr.split(",")) {
|
||||
supportedLocales.add(Locale.forLanguageTag(langTag));
|
||||
}
|
||||
closeAndLogError(modelFd);
|
||||
return new ModelFile(path, file.getName(), version, supportedLocales,
|
||||
languageIndependent);
|
||||
} catch (FileNotFoundException e) {
|
||||
Log.e(DEFAULT_LOG_TAG, "Failed to peek " + file.getAbsolutePath(), e);
|
||||
return null;
|
||||
} finally {
|
||||
maybeCloseAndLogError(modelFd);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -557,12 +645,12 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
public boolean equals(Object other) {
|
||||
if (this == other) {
|
||||
return true;
|
||||
} else if (other == null || !ModelFile.class.isAssignableFrom(other.getClass())) {
|
||||
return false;
|
||||
} else {
|
||||
}
|
||||
if (other instanceof ModelFile) {
|
||||
final ModelFile otherModel = (ModelFile) other;
|
||||
return mPath.equals(otherModel.mPath);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -677,10 +765,12 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
@NonNull
|
||||
public static List<LabeledIntent> create(
|
||||
Context context,
|
||||
String text,
|
||||
@Nullable Instant referenceTime,
|
||||
AnnotatorModel.ClassificationResult classification,
|
||||
String text) {
|
||||
final String type = classification.getCollection().trim().toLowerCase(Locale.ENGLISH);
|
||||
@Nullable AnnotatorModel.ClassificationResult classification) {
|
||||
final String type = classification != null
|
||||
? classification.getCollection().trim().toLowerCase(Locale.ENGLISH)
|
||||
: null;
|
||||
text = text.trim();
|
||||
switch (type) {
|
||||
case TextClassifier.TYPE_EMAIL:
|
||||
|
||||
Reference in New Issue
Block a user