Merge "Allowing models to support multiple languages"
This commit is contained in:
committed by
Android (Google) Code Review
commit
ea26b2a470
@@ -108,9 +108,9 @@ final class SmartSelection {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the language of the model.
|
||||
* Returns a comma separated list of locales supported by the model as BCP 47 tags.
|
||||
*/
|
||||
public static String getLanguage(int fd) {
|
||||
public static String getLanguages(int fd) {
|
||||
return nativeGetLanguage(fd);
|
||||
}
|
||||
|
||||
|
||||
@@ -58,6 +58,7 @@ import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.StringJoiner;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
@@ -101,11 +102,9 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
|
||||
private final Object mLock = new Object();
|
||||
@GuardedBy("mLock") // Do not access outside this lock.
|
||||
private Map<Locale, String> mModelFilePaths;
|
||||
private List<ModelFile> mAllModelFiles;
|
||||
@GuardedBy("mLock") // Do not access outside this lock.
|
||||
private Locale mLocale;
|
||||
@GuardedBy("mLock") // Do not access outside this lock.
|
||||
private int mVersion;
|
||||
private ModelFile mModel;
|
||||
@GuardedBy("mLock") // Do not access outside this lock.
|
||||
private SmartSelection mSmartSelection;
|
||||
|
||||
@@ -281,18 +280,18 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
private SmartSelection getSmartSelection(LocaleList localeList) throws FileNotFoundException {
|
||||
synchronized (mLock) {
|
||||
localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList;
|
||||
final Locale locale = findBestSupportedLocaleLocked(localeList);
|
||||
if (locale == null) {
|
||||
throw new FileNotFoundException("No file for null locale");
|
||||
final ModelFile bestModel = findBestModelLocked(localeList);
|
||||
if (bestModel == null) {
|
||||
throw new FileNotFoundException("No model for " + localeList.toLanguageTags());
|
||||
}
|
||||
if (mSmartSelection == null || !Objects.equals(mLocale, locale)) {
|
||||
if (mSmartSelection == null || !Objects.equals(mModel, bestModel)) {
|
||||
Log.d(DEFAULT_LOG_TAG, "Loading " + bestModel);
|
||||
destroySmartSelectionIfExistsLocked();
|
||||
final ParcelFileDescriptor fd = getFdLocked(locale);
|
||||
final int modelFd = fd.getFd();
|
||||
mVersion = SmartSelection.getVersion(modelFd);
|
||||
mSmartSelection = new SmartSelection(modelFd);
|
||||
final ParcelFileDescriptor fd = ParcelFileDescriptor.open(
|
||||
new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
|
||||
mSmartSelection = new SmartSelection(fd.getFd());
|
||||
closeAndLogError(fd);
|
||||
mLocale = locale;
|
||||
mModel = bestModel;
|
||||
}
|
||||
return mSmartSelection;
|
||||
}
|
||||
@@ -300,74 +299,8 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
|
||||
private String getSignature(String text, int start, int end) {
|
||||
synchronized (mLock) {
|
||||
return DefaultLogger.createSignature(text, start, end, mContext, mVersion, mLocale);
|
||||
}
|
||||
}
|
||||
|
||||
@GuardedBy("mLock") // Do not call outside this lock.
|
||||
private ParcelFileDescriptor getFdLocked(Locale locale) throws FileNotFoundException {
|
||||
ParcelFileDescriptor updateFd;
|
||||
int updateVersion = -1;
|
||||
try {
|
||||
updateFd = ParcelFileDescriptor.open(
|
||||
new File(UPDATED_MODEL_FILE_PATH), ParcelFileDescriptor.MODE_READ_ONLY);
|
||||
if (updateFd != null) {
|
||||
updateVersion = SmartSelection.getVersion(updateFd.getFd());
|
||||
}
|
||||
} catch (FileNotFoundException e) {
|
||||
updateFd = null;
|
||||
}
|
||||
ParcelFileDescriptor factoryFd;
|
||||
int factoryVersion = -1;
|
||||
try {
|
||||
final String factoryModelFilePath = getFactoryModelFilePathsLocked().get(locale);
|
||||
if (factoryModelFilePath != null) {
|
||||
factoryFd = ParcelFileDescriptor.open(
|
||||
new File(factoryModelFilePath), ParcelFileDescriptor.MODE_READ_ONLY);
|
||||
if (factoryFd != null) {
|
||||
factoryVersion = SmartSelection.getVersion(factoryFd.getFd());
|
||||
}
|
||||
} else {
|
||||
factoryFd = null;
|
||||
}
|
||||
} catch (FileNotFoundException e) {
|
||||
factoryFd = null;
|
||||
}
|
||||
|
||||
if (updateFd == null) {
|
||||
if (factoryFd != null) {
|
||||
return factoryFd;
|
||||
} else {
|
||||
throw new FileNotFoundException(
|
||||
String.format(Locale.US, "No model file found for %s", locale));
|
||||
}
|
||||
}
|
||||
|
||||
final int updateFdInt = updateFd.getFd();
|
||||
final boolean localeMatches = Objects.equals(
|
||||
locale.getLanguage().trim().toLowerCase(),
|
||||
SmartSelection.getLanguage(updateFdInt).trim().toLowerCase());
|
||||
if (factoryFd == null) {
|
||||
if (localeMatches) {
|
||||
return updateFd;
|
||||
} else {
|
||||
closeAndLogError(updateFd);
|
||||
throw new FileNotFoundException(
|
||||
String.format(Locale.US, "No model file found for %s", locale));
|
||||
}
|
||||
}
|
||||
|
||||
if (!localeMatches) {
|
||||
closeAndLogError(updateFd);
|
||||
return factoryFd;
|
||||
}
|
||||
|
||||
if (updateVersion > factoryVersion) {
|
||||
closeAndLogError(factoryFd);
|
||||
return updateFd;
|
||||
} else {
|
||||
closeAndLogError(updateFd);
|
||||
return factoryFd;
|
||||
return DefaultLogger.createSignature(text, start, end, mContext, mModel.getVersion(),
|
||||
mModel.getSupportedLocales());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -379,60 +312,66 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds the most appropriate model to use for the given target locale list.
|
||||
*
|
||||
* The basic logic is: we ignore all models that don't support any of the target locales. For
|
||||
* the remaining candidates, we take the update model unless its version number is lower than
|
||||
* the factory version. It's assumed that factory models do not have overlapping locale ranges
|
||||
* and conflict resolution between these models hence doesn't matter.
|
||||
*/
|
||||
@GuardedBy("mLock") // Do not call outside this lock.
|
||||
@Nullable
|
||||
private Locale findBestSupportedLocaleLocked(LocaleList localeList) {
|
||||
private ModelFile findBestModelLocked(LocaleList localeList) {
|
||||
// Specified localeList takes priority over the system default, so it is listed first.
|
||||
final String languages = localeList.isEmpty()
|
||||
? LocaleList.getDefault().toLanguageTags()
|
||||
: localeList.toLanguageTags() + "," + LocaleList.getDefault().toLanguageTags();
|
||||
final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);
|
||||
|
||||
final List<Locale> supportedLocales =
|
||||
new ArrayList<>(getFactoryModelFilePathsLocked().keySet());
|
||||
final Locale updatedModelLocale = getUpdatedModelLocale();
|
||||
if (updatedModelLocale != null) {
|
||||
supportedLocales.add(updatedModelLocale);
|
||||
ModelFile bestModel = null;
|
||||
int bestModelVersion = -1;
|
||||
for (ModelFile model : listAllModelsLocked()) {
|
||||
if (model.isAnyLanguageSupported(languageRangeList)) {
|
||||
if (model.getVersion() >= bestModelVersion) {
|
||||
bestModel = model;
|
||||
bestModelVersion = model.getVersion();
|
||||
}
|
||||
}
|
||||
}
|
||||
return Locale.lookup(languageRangeList, supportedLocales);
|
||||
return bestModel;
|
||||
}
|
||||
|
||||
/** Returns a list of all model files available, in order of precedence. */
|
||||
@GuardedBy("mLock") // Do not call outside this lock.
|
||||
private Map<Locale, String> getFactoryModelFilePathsLocked() {
|
||||
if (mModelFilePaths == null) {
|
||||
final Map<Locale, String> modelFilePaths = new HashMap<>();
|
||||
private List<ModelFile> listAllModelsLocked() {
|
||||
if (mAllModelFiles == null) {
|
||||
final List<ModelFile> allModels = new ArrayList<>();
|
||||
// The update model has the highest precedence.
|
||||
if (new File(UPDATED_MODEL_FILE_PATH).exists()) {
|
||||
final ModelFile updatedModel = ModelFile.fromPath(UPDATED_MODEL_FILE_PATH);
|
||||
if (updatedModel != null) {
|
||||
allModels.add(updatedModel);
|
||||
}
|
||||
}
|
||||
// Factory models should never have overlapping locales, so the order doesn't matter.
|
||||
final File modelsDir = new File(MODEL_DIR);
|
||||
if (modelsDir.exists() && modelsDir.isDirectory()) {
|
||||
final File[] models = modelsDir.listFiles();
|
||||
final File[] modelFiles = modelsDir.listFiles();
|
||||
final Pattern modelFilenamePattern = Pattern.compile(MODEL_FILE_REGEX);
|
||||
final int size = models.length;
|
||||
for (int i = 0; i < size; i++) {
|
||||
final File modelFile = models[i];
|
||||
for (File modelFile : modelFiles) {
|
||||
final Matcher matcher = modelFilenamePattern.matcher(modelFile.getName());
|
||||
if (matcher.matches() && modelFile.isFile()) {
|
||||
final String language = matcher.group(1);
|
||||
final Locale locale = Locale.forLanguageTag(language);
|
||||
modelFilePaths.put(locale, modelFile.getAbsolutePath());
|
||||
final ModelFile model = ModelFile.fromPath(modelFile.getAbsolutePath());
|
||||
if (model != null) {
|
||||
allModels.add(model);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
mModelFilePaths = modelFilePaths;
|
||||
}
|
||||
return mModelFilePaths;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private Locale getUpdatedModelLocale() {
|
||||
try {
|
||||
final ParcelFileDescriptor updateFd = ParcelFileDescriptor.open(
|
||||
new File(UPDATED_MODEL_FILE_PATH), ParcelFileDescriptor.MODE_READ_ONLY);
|
||||
final Locale locale = Locale.forLanguageTag(
|
||||
SmartSelection.getLanguage(updateFd.getFd()));
|
||||
closeAndLogError(updateFd);
|
||||
return locale;
|
||||
} catch (FileNotFoundException e) {
|
||||
return null;
|
||||
mAllModelFiles = allModels;
|
||||
}
|
||||
return mAllModelFiles;
|
||||
}
|
||||
|
||||
private TextClassification createClassificationResult(
|
||||
@@ -521,6 +460,95 @@ public final class TextClassifierImpl implements TextClassifier {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Describes TextClassifier model files on disk.
|
||||
*/
|
||||
private static final class ModelFile {
|
||||
|
||||
private final String mPath;
|
||||
private final String mName;
|
||||
private final int mVersion;
|
||||
private final List<Locale> mSupportedLocales;
|
||||
|
||||
/** Returns null if the path did not point to a compatible model. */
|
||||
static @Nullable ModelFile fromPath(String path) {
|
||||
final File file = new File(path);
|
||||
try {
|
||||
final ParcelFileDescriptor modelFd = ParcelFileDescriptor.open(
|
||||
file, ParcelFileDescriptor.MODE_READ_ONLY);
|
||||
final int version = SmartSelection.getVersion(modelFd.getFd());
|
||||
final String supportedLocalesStr = SmartSelection.getLanguages(modelFd.getFd());
|
||||
if (supportedLocalesStr.isEmpty()) {
|
||||
Log.d(DEFAULT_LOG_TAG, "Ignoring " + file.getAbsolutePath());
|
||||
return null;
|
||||
}
|
||||
final List<Locale> supportedLocales = new ArrayList<>();
|
||||
for (String langTag : supportedLocalesStr.split(",")) {
|
||||
supportedLocales.add(Locale.forLanguageTag(langTag));
|
||||
}
|
||||
closeAndLogError(modelFd);
|
||||
return new ModelFile(path, file.getName(), version, supportedLocales);
|
||||
} catch (FileNotFoundException e) {
|
||||
Log.e(DEFAULT_LOG_TAG, "Failed to peek " + file.getAbsolutePath(), e);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/** The absolute path to the model file. */
|
||||
String getPath() {
|
||||
return mPath;
|
||||
}
|
||||
|
||||
/** A name to use for signature generation. Effectively the name of the model file. */
|
||||
String getName() {
|
||||
return mName;
|
||||
}
|
||||
|
||||
/** Returns the version tag in the model's metadata. */
|
||||
int getVersion() {
|
||||
return mVersion;
|
||||
}
|
||||
|
||||
/** Returns whether the language supports any language in the given ranges. */
|
||||
boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) {
|
||||
return Locale.lookup(languageRanges, mSupportedLocales) != null;
|
||||
}
|
||||
|
||||
/** All locales supported by the model. */
|
||||
List<Locale> getSupportedLocales() {
|
||||
return Collections.unmodifiableList(mSupportedLocales);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object other) {
|
||||
if (this == other) {
|
||||
return true;
|
||||
} else if (other == null || !ModelFile.class.isAssignableFrom(other.getClass())) {
|
||||
return false;
|
||||
} else {
|
||||
final ModelFile otherModel = (ModelFile) other;
|
||||
return mPath.equals(otherModel.mPath);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
final StringJoiner localesJoiner = new StringJoiner(",");
|
||||
for (Locale locale : mSupportedLocales) {
|
||||
localesJoiner.add(locale.toLanguageTag());
|
||||
}
|
||||
return String.format(Locale.US, "ModelFile { path=%s name=%s version=%d locales=%s }",
|
||||
mPath, mName, mVersion, localesJoiner.toString());
|
||||
}
|
||||
|
||||
private ModelFile(String path, String name, int version, List<Locale> supportedLocales) {
|
||||
mPath = path;
|
||||
mName = name;
|
||||
mVersion = version;
|
||||
mSupportedLocales = supportedLocales;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates intents based on the classification type.
|
||||
*/
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
package android.view.textclassifier.logging;
|
||||
|
||||
import android.annotation.NonNull;
|
||||
import android.annotation.Nullable;
|
||||
import android.content.Context;
|
||||
import android.metrics.LogMaker;
|
||||
import android.util.Log;
|
||||
@@ -27,8 +26,10 @@ import com.android.internal.logging.MetricsLogger;
|
||||
import com.android.internal.logging.nano.MetricsProto.MetricsEvent;
|
||||
import com.android.internal.util.Preconditions;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import java.util.StringJoiner;
|
||||
|
||||
/**
|
||||
* Default Logger.
|
||||
@@ -210,12 +211,16 @@ public final class DefaultLogger extends Logger {
|
||||
*/
|
||||
public static String createSignature(
|
||||
String text, int start, int end, Context context, int modelVersion,
|
||||
@Nullable Locale locale) {
|
||||
List<Locale> locales) {
|
||||
Preconditions.checkNotNull(text);
|
||||
Preconditions.checkNotNull(context);
|
||||
final String modelName = (locale != null)
|
||||
? String.format(Locale.US, "%s_v%d", locale.toLanguageTag(), modelVersion)
|
||||
: "";
|
||||
Preconditions.checkNotNull(locales);
|
||||
final StringJoiner localesJoiner = new StringJoiner(",");
|
||||
for (Locale locale : locales) {
|
||||
localesJoiner.add(locale.toLanguageTag());
|
||||
}
|
||||
final String modelName = String.format(Locale.US, "%s_v%d", localesJoiner.toString(),
|
||||
modelVersion);
|
||||
final int hash = Objects.hash(text, start, end, context.getPackageName());
|
||||
return SignatureParser.createSignature(CLASSIFIER_ID, modelName, hash);
|
||||
}
|
||||
@@ -242,9 +247,9 @@ public final class DefaultLogger extends Logger {
|
||||
|
||||
static String getModelName(String signature) {
|
||||
Preconditions.checkNotNull(signature);
|
||||
final int start = signature.indexOf("|");
|
||||
final int start = signature.indexOf("|") + 1;
|
||||
final int end = signature.indexOf("|", start);
|
||||
if (start >= 0 && end >= start) {
|
||||
if (start >= 1 && end >= start) {
|
||||
return signature.substring(start, end);
|
||||
}
|
||||
return "";
|
||||
|
||||
Reference in New Issue
Block a user