Merge "Switch TextClassifier implementation from native to java"

This commit is contained in:
Nikita Iashchenko
2019-01-17 15:34:02 +00:00
committed by Gerrit Code Review
3 changed files with 25 additions and 321 deletions

View File

@@ -704,7 +704,6 @@ java_defaults {
required: [
// TODO: remove gps_debug when the build system propagates "required" properly.
"gps_debug.conf",
"libtextclassifier",
// Loaded with System.loadLibrary by android.view.textclassifier
"libmedia2_jni",
],
@@ -855,6 +854,10 @@ java_library {
"nist-sip",
"tagsoup",
"rappor",
"libtextclassifier-java",
],
required: [
"libtextclassifier",
],
dxflags: ["--core-library"],
}

View File

@@ -43,6 +43,8 @@ import android.provider.ContactsContract;
import com.android.internal.annotations.GuardedBy;
import com.android.internal.util.Preconditions;
import com.google.android.textclassifier.AnnotatorModel;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
@@ -91,7 +93,7 @@ public final class TextClassifierImpl implements TextClassifier {
@GuardedBy("mLock") // Do not access outside this lock.
private ModelFile mModel;
@GuardedBy("mLock") // Do not access outside this lock.
private TextClassifierImplNative mNative;
private AnnotatorModel mNative;
private final Object mLoggerLock = new Object();
@GuardedBy("mLoggerLock") // Do not access outside this lock.
@@ -124,7 +126,7 @@ public final class TextClassifierImpl implements TextClassifier {
&& rangeLength <= mSettings.getSuggestSelectionMaxRangeLength()) {
final String localesString = concatenateLocales(request.getDefaultLocales());
final ZonedDateTime refTime = ZonedDateTime.now();
final TextClassifierImplNative nativeImpl = getNative(request.getDefaultLocales());
final AnnotatorModel nativeImpl = getNative(request.getDefaultLocales());
final int start;
final int end;
if (mSettings.isModelDarkLaunchEnabled() && !request.isDarkLaunchAllowed()) {
@@ -133,7 +135,7 @@ public final class TextClassifierImpl implements TextClassifier {
} else {
final int[] startEnd = nativeImpl.suggestSelection(
string, request.getStartIndex(), request.getEndIndex(),
new TextClassifierImplNative.SelectionOptions(localesString));
new AnnotatorModel.SelectionOptions(localesString));
start = startEnd[0];
end = startEnd[1];
}
@@ -141,10 +143,10 @@ public final class TextClassifierImpl implements TextClassifier {
&& start >= 0 && end <= string.length()
&& start <= request.getStartIndex() && end >= request.getEndIndex()) {
final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);
final TextClassifierImplNative.ClassificationResult[] results =
final AnnotatorModel.ClassificationResult[] results =
nativeImpl.classifyText(
string, start, end,
new TextClassifierImplNative.ClassificationOptions(
new AnnotatorModel.ClassificationOptions(
refTime.toInstant().toEpochMilli(),
refTime.getZone().getId(),
localesString));
@@ -183,11 +185,11 @@ public final class TextClassifierImpl implements TextClassifier {
final String localesString = concatenateLocales(request.getDefaultLocales());
final ZonedDateTime refTime = request.getReferenceTime() != null
? request.getReferenceTime() : ZonedDateTime.now();
final TextClassifierImplNative.ClassificationResult[] results =
final AnnotatorModel.ClassificationResult[] results =
getNative(request.getDefaultLocales())
.classifyText(
string, request.getStartIndex(), request.getEndIndex(),
new TextClassifierImplNative.ClassificationOptions(
new AnnotatorModel.ClassificationOptions(
refTime.toInstant().toEpochMilli(),
refTime.getZone().getId(),
localesString));
@@ -227,17 +229,17 @@ public final class TextClassifierImpl implements TextClassifier {
? request.getEntityConfig().resolveEntityListModifications(
getEntitiesForHints(request.getEntityConfig().getHints()))
: mSettings.getEntityListDefault();
final TextClassifierImplNative nativeImpl =
final AnnotatorModel nativeImpl =
getNative(request.getDefaultLocales());
final TextClassifierImplNative.AnnotatedSpan[] annotations =
final AnnotatorModel.AnnotatedSpan[] annotations =
nativeImpl.annotate(
textString,
new TextClassifierImplNative.AnnotationOptions(
new AnnotatorModel.AnnotationOptions(
refTime.toInstant().toEpochMilli(),
refTime.getZone().getId(),
concatenateLocales(request.getDefaultLocales())));
for (TextClassifierImplNative.AnnotatedSpan span : annotations) {
final TextClassifierImplNative.ClassificationResult[] results =
for (AnnotatorModel.AnnotatedSpan span : annotations) {
final AnnotatorModel.ClassificationResult[] results =
span.getClassification();
if (results.length == 0
|| !entitiesToIdentify.contains(results[0].getCollection())) {
@@ -296,7 +298,7 @@ public final class TextClassifierImpl implements TextClassifier {
}
}
private TextClassifierImplNative getNative(LocaleList localeList)
private AnnotatorModel getNative(LocaleList localeList)
throws FileNotFoundException {
synchronized (mLock) {
localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList;
@@ -309,7 +311,7 @@ public final class TextClassifierImpl implements TextClassifier {
destroyNativeIfExistsLocked();
final ParcelFileDescriptor fd = ParcelFileDescriptor.open(
new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
mNative = new TextClassifierImplNative(fd.getFd());
mNative = new AnnotatorModel(fd.getFd());
closeAndLogError(fd);
mModel = bestModel;
}
@@ -397,14 +399,14 @@ public final class TextClassifierImpl implements TextClassifier {
}
private TextClassification createClassificationResult(
TextClassifierImplNative.ClassificationResult[] classifications,
AnnotatorModel.ClassificationResult[] classifications,
String text, int start, int end, @Nullable Instant referenceTime) {
final String classifiedText = text.substring(start, end);
final TextClassification.Builder builder = new TextClassification.Builder()
.setText(classifiedText);
final int size = classifications.length;
TextClassifierImplNative.ClassificationResult highestScoringResult = null;
AnnotatorModel.ClassificationResult highestScoringResult = null;
float highestScore = Float.MIN_VALUE;
for (int i = 0; i < size; i++) {
builder.setEntityType(classifications[i].getCollection(),
@@ -467,9 +469,9 @@ public final class TextClassifierImpl implements TextClassifier {
try {
final ParcelFileDescriptor modelFd = ParcelFileDescriptor.open(
file, ParcelFileDescriptor.MODE_READ_ONLY);
final int version = TextClassifierImplNative.getVersion(modelFd.getFd());
final int version = AnnotatorModel.getVersion(modelFd.getFd());
final String supportedLocalesStr =
TextClassifierImplNative.getLocales(modelFd.getFd());
AnnotatorModel.getLocales(modelFd.getFd());
if (supportedLocalesStr.isEmpty()) {
Log.d(DEFAULT_LOG_TAG, "Ignoring " + file.getAbsolutePath());
return null;
@@ -657,7 +659,7 @@ public final class TextClassifierImpl implements TextClassifier {
public static List<LabeledIntent> create(
Context context,
@Nullable Instant referenceTime,
TextClassifierImplNative.ClassificationResult classification,
AnnotatorModel.ClassificationResult classification,
String text) {
final String type = classification.getCollection().trim().toLowerCase(Locale.ENGLISH);
text = text.trim();

View File

@@ -1,301 +0,0 @@
/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package android.view.textclassifier;
import android.content.res.AssetFileDescriptor;
/**
* Java wrapper for TextClassifier native library interface. This library is used for detecting
* entities in text.
*/
final class TextClassifierImplNative {
static {
System.loadLibrary("textclassifier");
}
private final long mModelPtr;
/**
* Creates a new instance of TextClassifierImplNative, using the provided model image, given as
* a file descriptor.
*/
TextClassifierImplNative(int fd) {
mModelPtr = nativeNew(fd);
if (mModelPtr == 0L) {
throw new IllegalArgumentException("Couldn't initialize TC from file descriptor.");
}
}
/**
* Creates a new instance of TextClassifierImplNative, using the provided model image, given as
* a file path.
*/
TextClassifierImplNative(String path) {
mModelPtr = nativeNewFromPath(path);
if (mModelPtr == 0L) {
throw new IllegalArgumentException("Couldn't initialize TC from given file.");
}
}
/**
* Creates a new instance of TextClassifierImplNative, using the provided model image, given as
* an AssetFileDescriptor.
*/
TextClassifierImplNative(AssetFileDescriptor afd) {
mModelPtr = nativeNewFromAssetFileDescriptor(afd, afd.getStartOffset(), afd.getLength());
if (mModelPtr == 0L) {
throw new IllegalArgumentException(
"Couldn't initialize TC from given AssetFileDescriptor");
}
}
/**
* Given a string context and current selection, computes the SmartSelection suggestion.
*
* <p>The begin and end are character indices into the context UTF8 string. selectionBegin is
* the character index where the selection begins, and selectionEnd is the index of one
* character past the selection span.
*
* <p>The return value is an array of two ints: suggested selection beginning and end, with the
* same semantics as the input selectionBeginning and selectionEnd.
*/
public int[] suggestSelection(
String context, int selectionBegin, int selectionEnd, SelectionOptions options) {
return nativeSuggestSelection(mModelPtr, context, selectionBegin, selectionEnd, options);
}
/**
* Given a string context and current selection, classifies the type of the selected text.
*
* <p>The begin and end params are character indices in the context string.
*
* <p>Returns an array of ClassificationResult objects with the probability scores for different
* collections.
*/
public ClassificationResult[] classifyText(
String context, int selectionBegin, int selectionEnd, ClassificationOptions options) {
return nativeClassifyText(mModelPtr, context, selectionBegin, selectionEnd, options);
}
/**
* Annotates given input text. The annotations should cover the whole input context except for
* whitespaces, and are sorted by their position in the context string.
*/
public AnnotatedSpan[] annotate(String text, AnnotationOptions options) {
return nativeAnnotate(mModelPtr, text, options);
}
/** Frees up the allocated memory. */
public void close() {
nativeClose(mModelPtr);
}
/** Returns a comma separated list of locales supported by the model as BCP 47 tags. */
public static String getLocales(int fd) {
return nativeGetLocales(fd);
}
/** Returns the version of the model. */
public static int getVersion(int fd) {
return nativeGetVersion(fd);
}
/** Represents a datetime parsing result from classifyText calls. */
public static final class DatetimeResult {
static final int GRANULARITY_YEAR = 0;
static final int GRANULARITY_MONTH = 1;
static final int GRANULARITY_WEEK = 2;
static final int GRANULARITY_DAY = 3;
static final int GRANULARITY_HOUR = 4;
static final int GRANULARITY_MINUTE = 5;
static final int GRANULARITY_SECOND = 6;
private final long mTimeMsUtc;
private final int mGranularity;
DatetimeResult(long timeMsUtc, int granularity) {
mGranularity = granularity;
mTimeMsUtc = timeMsUtc;
}
public long getTimeMsUtc() {
return mTimeMsUtc;
}
public int getGranularity() {
return mGranularity;
}
}
/** Represents a result of classifyText method call. */
public static final class ClassificationResult {
private final String mCollection;
private final float mScore;
private final DatetimeResult mDatetimeResult;
ClassificationResult(
String collection, float score, DatetimeResult datetimeResult) {
mCollection = collection;
mScore = score;
mDatetimeResult = datetimeResult;
}
public String getCollection() {
if (mCollection.equals(TextClassifier.TYPE_DATE) && mDatetimeResult != null) {
switch (mDatetimeResult.getGranularity()) {
case DatetimeResult.GRANULARITY_HOUR:
// fall through
case DatetimeResult.GRANULARITY_MINUTE:
// fall through
case DatetimeResult.GRANULARITY_SECOND:
return TextClassifier.TYPE_DATE_TIME;
default:
return TextClassifier.TYPE_DATE;
}
}
return mCollection;
}
public float getScore() {
return mScore;
}
public DatetimeResult getDatetimeResult() {
return mDatetimeResult;
}
}
/** Represents a result of Annotate call. */
public static final class AnnotatedSpan {
private final int mStartIndex;
private final int mEndIndex;
private final ClassificationResult[] mClassification;
AnnotatedSpan(
int startIndex, int endIndex, ClassificationResult[] classification) {
mStartIndex = startIndex;
mEndIndex = endIndex;
mClassification = classification;
}
public int getStartIndex() {
return mStartIndex;
}
public int getEndIndex() {
return mEndIndex;
}
public ClassificationResult[] getClassification() {
return mClassification;
}
}
/** Represents options for the suggestSelection call. */
public static final class SelectionOptions {
private final String mLocales;
SelectionOptions(String locales) {
mLocales = locales;
}
public String getLocales() {
return mLocales;
}
}
/** Represents options for the classifyText call. */
public static final class ClassificationOptions {
private final long mReferenceTimeMsUtc;
private final String mReferenceTimezone;
private final String mLocales;
ClassificationOptions(long referenceTimeMsUtc, String referenceTimezone, String locale) {
mReferenceTimeMsUtc = referenceTimeMsUtc;
mReferenceTimezone = referenceTimezone;
mLocales = locale;
}
public long getReferenceTimeMsUtc() {
return mReferenceTimeMsUtc;
}
public String getReferenceTimezone() {
return mReferenceTimezone;
}
public String getLocale() {
return mLocales;
}
}
/** Represents options for the Annotate call. */
public static final class AnnotationOptions {
private final long mReferenceTimeMsUtc;
private final String mReferenceTimezone;
private final String mLocales;
AnnotationOptions(long referenceTimeMsUtc, String referenceTimezone, String locale) {
mReferenceTimeMsUtc = referenceTimeMsUtc;
mReferenceTimezone = referenceTimezone;
mLocales = locale;
}
public long getReferenceTimeMsUtc() {
return mReferenceTimeMsUtc;
}
public String getReferenceTimezone() {
return mReferenceTimezone;
}
public String getLocale() {
return mLocales;
}
}
private static native long nativeNew(int fd);
private static native long nativeNewFromPath(String path);
private static native long nativeNewFromAssetFileDescriptor(
AssetFileDescriptor afd, long offset, long size);
private static native int[] nativeSuggestSelection(
long context,
String text,
int selectionBegin,
int selectionEnd,
SelectionOptions options);
private static native ClassificationResult[] nativeClassifyText(
long context,
String text,
int selectionBegin,
int selectionEnd,
ClassificationOptions options);
private static native AnnotatedSpan[] nativeAnnotate(
long context, String text, AnnotationOptions options);
private static native void nativeClose(long context);
private static native String nativeGetLocales(int fd);
private static native int nativeGetVersion(int fd);
}