Merge "Allowing models to support multiple languages"

This commit is contained in:
TreeHugger Robot
2018-02-06 14:51:03 +00:00
committed by Android (Google) Code Review
3 changed files with 156 additions and 123 deletions

View File

@@ -108,9 +108,9 @@ final class SmartSelection {
}
/**
* Returns the language of the model.
* Returns a comma separated list of locales supported by the model as BCP 47 tags.
*/
public static String getLanguage(int fd) {
public static String getLanguages(int fd) {
return nativeGetLanguage(fd);
}

View File

@@ -58,6 +58,7 @@ import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.StringJoiner;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -101,11 +102,9 @@ public final class TextClassifierImpl implements TextClassifier {
private final Object mLock = new Object();
@GuardedBy("mLock") // Do not access outside this lock.
private Map<Locale, String> mModelFilePaths;
private List<ModelFile> mAllModelFiles;
@GuardedBy("mLock") // Do not access outside this lock.
private Locale mLocale;
@GuardedBy("mLock") // Do not access outside this lock.
private int mVersion;
private ModelFile mModel;
@GuardedBy("mLock") // Do not access outside this lock.
private SmartSelection mSmartSelection;
@@ -281,18 +280,18 @@ public final class TextClassifierImpl implements TextClassifier {
private SmartSelection getSmartSelection(LocaleList localeList) throws FileNotFoundException {
synchronized (mLock) {
localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList;
final Locale locale = findBestSupportedLocaleLocked(localeList);
if (locale == null) {
throw new FileNotFoundException("No file for null locale");
final ModelFile bestModel = findBestModelLocked(localeList);
if (bestModel == null) {
throw new FileNotFoundException("No model for " + localeList.toLanguageTags());
}
if (mSmartSelection == null || !Objects.equals(mLocale, locale)) {
if (mSmartSelection == null || !Objects.equals(mModel, bestModel)) {
Log.d(DEFAULT_LOG_TAG, "Loading " + bestModel);
destroySmartSelectionIfExistsLocked();
final ParcelFileDescriptor fd = getFdLocked(locale);
final int modelFd = fd.getFd();
mVersion = SmartSelection.getVersion(modelFd);
mSmartSelection = new SmartSelection(modelFd);
final ParcelFileDescriptor fd = ParcelFileDescriptor.open(
new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
mSmartSelection = new SmartSelection(fd.getFd());
closeAndLogError(fd);
mLocale = locale;
mModel = bestModel;
}
return mSmartSelection;
}
@@ -300,74 +299,8 @@ public final class TextClassifierImpl implements TextClassifier {
private String getSignature(String text, int start, int end) {
synchronized (mLock) {
return DefaultLogger.createSignature(text, start, end, mContext, mVersion, mLocale);
}
}
@GuardedBy("mLock") // Do not call outside this lock.
private ParcelFileDescriptor getFdLocked(Locale locale) throws FileNotFoundException {
ParcelFileDescriptor updateFd;
int updateVersion = -1;
try {
updateFd = ParcelFileDescriptor.open(
new File(UPDATED_MODEL_FILE_PATH), ParcelFileDescriptor.MODE_READ_ONLY);
if (updateFd != null) {
updateVersion = SmartSelection.getVersion(updateFd.getFd());
}
} catch (FileNotFoundException e) {
updateFd = null;
}
ParcelFileDescriptor factoryFd;
int factoryVersion = -1;
try {
final String factoryModelFilePath = getFactoryModelFilePathsLocked().get(locale);
if (factoryModelFilePath != null) {
factoryFd = ParcelFileDescriptor.open(
new File(factoryModelFilePath), ParcelFileDescriptor.MODE_READ_ONLY);
if (factoryFd != null) {
factoryVersion = SmartSelection.getVersion(factoryFd.getFd());
}
} else {
factoryFd = null;
}
} catch (FileNotFoundException e) {
factoryFd = null;
}
if (updateFd == null) {
if (factoryFd != null) {
return factoryFd;
} else {
throw new FileNotFoundException(
String.format(Locale.US, "No model file found for %s", locale));
}
}
final int updateFdInt = updateFd.getFd();
final boolean localeMatches = Objects.equals(
locale.getLanguage().trim().toLowerCase(),
SmartSelection.getLanguage(updateFdInt).trim().toLowerCase());
if (factoryFd == null) {
if (localeMatches) {
return updateFd;
} else {
closeAndLogError(updateFd);
throw new FileNotFoundException(
String.format(Locale.US, "No model file found for %s", locale));
}
}
if (!localeMatches) {
closeAndLogError(updateFd);
return factoryFd;
}
if (updateVersion > factoryVersion) {
closeAndLogError(factoryFd);
return updateFd;
} else {
closeAndLogError(updateFd);
return factoryFd;
return DefaultLogger.createSignature(text, start, end, mContext, mModel.getVersion(),
mModel.getSupportedLocales());
}
}
@@ -379,60 +312,66 @@ public final class TextClassifierImpl implements TextClassifier {
}
}
/**
* Finds the most appropriate model to use for the given target locale list.
*
* The basic logic is: we ignore all models that don't support any of the target locales. For
* the remaining candidates, we take the update model unless its version number is lower than
* the factory version. It's assumed that factory models do not have overlapping locale ranges
* and conflict resolution between these models hence doesn't matter.
*/
@GuardedBy("mLock") // Do not call outside this lock.
@Nullable
private Locale findBestSupportedLocaleLocked(LocaleList localeList) {
private ModelFile findBestModelLocked(LocaleList localeList) {
// Specified localeList takes priority over the system default, so it is listed first.
final String languages = localeList.isEmpty()
? LocaleList.getDefault().toLanguageTags()
: localeList.toLanguageTags() + "," + LocaleList.getDefault().toLanguageTags();
final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);
final List<Locale> supportedLocales =
new ArrayList<>(getFactoryModelFilePathsLocked().keySet());
final Locale updatedModelLocale = getUpdatedModelLocale();
if (updatedModelLocale != null) {
supportedLocales.add(updatedModelLocale);
ModelFile bestModel = null;
int bestModelVersion = -1;
for (ModelFile model : listAllModelsLocked()) {
if (model.isAnyLanguageSupported(languageRangeList)) {
if (model.getVersion() >= bestModelVersion) {
bestModel = model;
bestModelVersion = model.getVersion();
}
}
}
return Locale.lookup(languageRangeList, supportedLocales);
return bestModel;
}
/** Returns a list of all model files available, in order of precedence. */
@GuardedBy("mLock") // Do not call outside this lock.
private Map<Locale, String> getFactoryModelFilePathsLocked() {
if (mModelFilePaths == null) {
final Map<Locale, String> modelFilePaths = new HashMap<>();
private List<ModelFile> listAllModelsLocked() {
if (mAllModelFiles == null) {
final List<ModelFile> allModels = new ArrayList<>();
// The update model has the highest precedence.
if (new File(UPDATED_MODEL_FILE_PATH).exists()) {
final ModelFile updatedModel = ModelFile.fromPath(UPDATED_MODEL_FILE_PATH);
if (updatedModel != null) {
allModels.add(updatedModel);
}
}
// Factory models should never have overlapping locales, so the order doesn't matter.
final File modelsDir = new File(MODEL_DIR);
if (modelsDir.exists() && modelsDir.isDirectory()) {
final File[] models = modelsDir.listFiles();
final File[] modelFiles = modelsDir.listFiles();
final Pattern modelFilenamePattern = Pattern.compile(MODEL_FILE_REGEX);
final int size = models.length;
for (int i = 0; i < size; i++) {
final File modelFile = models[i];
for (File modelFile : modelFiles) {
final Matcher matcher = modelFilenamePattern.matcher(modelFile.getName());
if (matcher.matches() && modelFile.isFile()) {
final String language = matcher.group(1);
final Locale locale = Locale.forLanguageTag(language);
modelFilePaths.put(locale, modelFile.getAbsolutePath());
final ModelFile model = ModelFile.fromPath(modelFile.getAbsolutePath());
if (model != null) {
allModels.add(model);
}
}
}
}
mModelFilePaths = modelFilePaths;
}
return mModelFilePaths;
}
@Nullable
private Locale getUpdatedModelLocale() {
try {
final ParcelFileDescriptor updateFd = ParcelFileDescriptor.open(
new File(UPDATED_MODEL_FILE_PATH), ParcelFileDescriptor.MODE_READ_ONLY);
final Locale locale = Locale.forLanguageTag(
SmartSelection.getLanguage(updateFd.getFd()));
closeAndLogError(updateFd);
return locale;
} catch (FileNotFoundException e) {
return null;
mAllModelFiles = allModels;
}
return mAllModelFiles;
}
private TextClassification createClassificationResult(
@@ -521,6 +460,95 @@ public final class TextClassifierImpl implements TextClassifier {
}
}
/**
* Describes TextClassifier model files on disk.
*/
private static final class ModelFile {
private final String mPath;
private final String mName;
private final int mVersion;
private final List<Locale> mSupportedLocales;
/** Returns null if the path did not point to a compatible model. */
static @Nullable ModelFile fromPath(String path) {
final File file = new File(path);
try {
final ParcelFileDescriptor modelFd = ParcelFileDescriptor.open(
file, ParcelFileDescriptor.MODE_READ_ONLY);
final int version = SmartSelection.getVersion(modelFd.getFd());
final String supportedLocalesStr = SmartSelection.getLanguages(modelFd.getFd());
if (supportedLocalesStr.isEmpty()) {
Log.d(DEFAULT_LOG_TAG, "Ignoring " + file.getAbsolutePath());
return null;
}
final List<Locale> supportedLocales = new ArrayList<>();
for (String langTag : supportedLocalesStr.split(",")) {
supportedLocales.add(Locale.forLanguageTag(langTag));
}
closeAndLogError(modelFd);
return new ModelFile(path, file.getName(), version, supportedLocales);
} catch (FileNotFoundException e) {
Log.e(DEFAULT_LOG_TAG, "Failed to peek " + file.getAbsolutePath(), e);
return null;
}
}
/** The absolute path to the model file. */
String getPath() {
return mPath;
}
/** A name to use for signature generation. Effectively the name of the model file. */
String getName() {
return mName;
}
/** Returns the version tag in the model's metadata. */
int getVersion() {
return mVersion;
}
/** Returns whether the language supports any language in the given ranges. */
boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) {
return Locale.lookup(languageRanges, mSupportedLocales) != null;
}
/** All locales supported by the model. */
List<Locale> getSupportedLocales() {
return Collections.unmodifiableList(mSupportedLocales);
}
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
} else if (other == null || !ModelFile.class.isAssignableFrom(other.getClass())) {
return false;
} else {
final ModelFile otherModel = (ModelFile) other;
return mPath.equals(otherModel.mPath);
}
}
@Override
public String toString() {
final StringJoiner localesJoiner = new StringJoiner(",");
for (Locale locale : mSupportedLocales) {
localesJoiner.add(locale.toLanguageTag());
}
return String.format(Locale.US, "ModelFile { path=%s name=%s version=%d locales=%s }",
mPath, mName, mVersion, localesJoiner.toString());
}
private ModelFile(String path, String name, int version, List<Locale> supportedLocales) {
mPath = path;
mName = name;
mVersion = version;
mSupportedLocales = supportedLocales;
}
}
/**
* Creates intents based on the classification type.
*/

View File

@@ -17,7 +17,6 @@
package android.view.textclassifier.logging;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.content.Context;
import android.metrics.LogMaker;
import android.util.Log;
@@ -27,8 +26,10 @@ import com.android.internal.logging.MetricsLogger;
import com.android.internal.logging.nano.MetricsProto.MetricsEvent;
import com.android.internal.util.Preconditions;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.StringJoiner;
/**
* Default Logger.
@@ -210,12 +211,16 @@ public final class DefaultLogger extends Logger {
*/
public static String createSignature(
String text, int start, int end, Context context, int modelVersion,
@Nullable Locale locale) {
List<Locale> locales) {
Preconditions.checkNotNull(text);
Preconditions.checkNotNull(context);
final String modelName = (locale != null)
? String.format(Locale.US, "%s_v%d", locale.toLanguageTag(), modelVersion)
: "";
Preconditions.checkNotNull(locales);
final StringJoiner localesJoiner = new StringJoiner(",");
for (Locale locale : locales) {
localesJoiner.add(locale.toLanguageTag());
}
final String modelName = String.format(Locale.US, "%s_v%d", localesJoiner.toString(),
modelVersion);
final int hash = Objects.hash(text, start, end, context.getPackageName());
return SignatureParser.createSignature(CLASSIFIER_ID, modelName, hash);
}
@@ -242,9 +247,9 @@ public final class DefaultLogger extends Logger {
static String getModelName(String signature) {
Preconditions.checkNotNull(signature);
final int start = signature.indexOf("|");
final int start = signature.indexOf("|") + 1;
final int end = signature.indexOf("|", start);
if (start >= 0 && end >= start) {
if (start >= 1 && end >= start) {
return signature.substring(start, end);
}
return "";