RESTRICT AUTOMERGE TextClassifier cross-user vulnerability in direct-reply

Sys UI runs on user 0. This can lead to the TextClassifier (TC)
running for the wrong user. Consequencies are user A can launch apps
in user B via the TC's predicted actions and selected text being
unintentionally shared from user A to an app running in user B.

This fix ensures that the correct user id is passed and verified for
every TC request going across process boundaries (i.e. via SystemTC).
- Sys UI sets the appropriate user id in the TextView
- TextClassificationManager (TCM) system service is constructed using
  a context generated from this user id
- SystemTC sets this user id before querying the TCMService
- TCMService validates the user id before forwarding the request to
  the TCService belonging to that user id.

Bug: 136483597
Bug: 123232892
Test: atest android.view.textclassifier
      atest android.widget.TextViewActivityTest
      (manual) See I2fdffd8eb4221782cb1f34d2ddbe41dd3d36595c

Change-Id: Ibe68bc9e257521de97cbb014176b2b8ba23547d1
This commit is contained in:
Abodunrinwa Toki
2019-07-01 19:41:44 +01:00
parent 8931739c16
commit 34e380cdd6
12 changed files with 289 additions and 50 deletions

View File

@@ -60,7 +60,9 @@ public final class ActionsModelParamsSupplier implements
private boolean mParsed = true;
public ActionsModelParamsSupplier(Context context, @Nullable Runnable onChangedListener) {
mAppContext = Preconditions.checkNotNull(context).getApplicationContext();
final Context appContext = Preconditions.checkNotNull(context).getApplicationContext();
// Some contexts don't have an app context.
mAppContext = appContext != null ? appContext : context;
mOnChangedListener = onChangedListener == null ? () -> {} : onChangedListener;
mSettingsObserver = new SettingsObserver(mAppContext, () -> {
synchronized (mLock) {

View File

@@ -21,10 +21,12 @@ import android.annotation.IntRange;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.annotation.StringDef;
import android.annotation.UserIdInt;
import android.app.Person;
import android.os.Bundle;
import android.os.Parcel;
import android.os.Parcelable;
import android.os.UserHandle;
import android.text.SpannedString;
import com.android.internal.annotations.VisibleForTesting;
@@ -316,6 +318,8 @@ public final class ConversationActions implements Parcelable {
private final List<String> mHints;
@Nullable
private String mCallingPackageName;
@UserIdInt
private int mUserId = UserHandle.USER_NULL;
@NonNull
private Bundle mExtras;
@@ -340,6 +344,7 @@ public final class ConversationActions implements Parcelable {
List<String> hints = new ArrayList<>();
in.readStringList(hints);
String callingPackageName = in.readString();
int userId = in.readInt();
Bundle extras = in.readBundle();
Request request = new Request(
conversation,
@@ -348,6 +353,7 @@ public final class ConversationActions implements Parcelable {
hints,
extras);
request.setCallingPackageName(callingPackageName);
request.setUserId(userId);
return request;
}
@@ -358,6 +364,7 @@ public final class ConversationActions implements Parcelable {
parcel.writeInt(mMaxSuggestions);
parcel.writeStringList(mHints);
parcel.writeString(mCallingPackageName);
parcel.writeInt(mUserId);
parcel.writeBundle(mExtras);
}
@@ -427,6 +434,24 @@ public final class ConversationActions implements Parcelable {
return mCallingPackageName;
}
/**
* Sets the id of the user that sent this request.
* <p>
* Package-private for SystemTextClassifier's use.
*/
void setUserId(@UserIdInt int userId) {
mUserId = userId;
}
/**
* Returns the id of the user that sent this request.
* @hide
*/
@UserIdInt
public int getUserId() {
return mUserId;
}
/**
* Returns the extended data related to this request.
*

View File

@@ -19,8 +19,10 @@ package android.view.textclassifier;
import android.annotation.IntDef;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.annotation.UserIdInt;
import android.os.Parcel;
import android.os.Parcelable;
import android.os.UserHandle;
import android.view.textclassifier.TextClassifier.EntityType;
import android.view.textclassifier.TextClassifier.WidgetType;
@@ -127,6 +129,7 @@ public final class SelectionEvent implements Parcelable {
private String mWidgetType = TextClassifier.WIDGET_TYPE_UNKNOWN;
private @InvocationMethod int mInvocationMethod;
@Nullable private String mWidgetVersion;
private @UserIdInt int mUserId = UserHandle.USER_NULL;
@Nullable private String mResultId;
private long mEventTime;
private long mDurationSinceSessionStart;
@@ -158,6 +161,7 @@ public final class SelectionEvent implements Parcelable {
mEntityType = in.readString();
mWidgetVersion = in.readInt() > 0 ? in.readString() : null;
mPackageName = in.readString();
mUserId = in.readInt();
mWidgetType = in.readString();
mInvocationMethod = in.readInt();
mResultId = in.readString();
@@ -184,6 +188,7 @@ public final class SelectionEvent implements Parcelable {
dest.writeString(mWidgetVersion);
}
dest.writeString(mPackageName);
dest.writeInt(mUserId);
dest.writeString(mWidgetType);
dest.writeInt(mInvocationMethod);
dest.writeString(mResultId);
@@ -400,6 +405,24 @@ public final class SelectionEvent implements Parcelable {
return mPackageName;
}
/**
* Sets the id of this event's user.
* <p>
* Package-private for SystemTextClassifier's use.
*/
void setUserId(@UserIdInt int userId) {
mUserId = userId;
}
/**
* Returns the id of this event's user.
* @hide
*/
@UserIdInt
public int getUserId() {
return mUserId;
}
/**
* Returns the type of widget that was involved in triggering this event.
*/
@@ -426,6 +449,7 @@ public final class SelectionEvent implements Parcelable {
mPackageName = context.getPackageName();
mWidgetType = context.getWidgetType();
mWidgetVersion = context.getWidgetVersion();
mUserId = context.getUserId();
}
/**
@@ -612,7 +636,7 @@ public final class SelectionEvent implements Parcelable {
@Override
public int hashCode() {
return Objects.hash(mAbsoluteStart, mAbsoluteEnd, mEventType, mEntityType,
mWidgetVersion, mPackageName, mWidgetType, mInvocationMethod, mResultId,
mWidgetVersion, mPackageName, mUserId, mWidgetType, mInvocationMethod, mResultId,
mEventTime, mDurationSinceSessionStart, mDurationSincePreviousEvent,
mEventIndex, mSessionId, mStart, mEnd, mSmartStart, mSmartEnd);
}
@@ -633,6 +657,7 @@ public final class SelectionEvent implements Parcelable {
&& Objects.equals(mEntityType, other.mEntityType)
&& Objects.equals(mWidgetVersion, other.mWidgetVersion)
&& Objects.equals(mPackageName, other.mPackageName)
&& mUserId == other.mUserId
&& Objects.equals(mWidgetType, other.mWidgetType)
&& mInvocationMethod == other.mInvocationMethod
&& Objects.equals(mResultId, other.mResultId)
@@ -652,12 +677,12 @@ public final class SelectionEvent implements Parcelable {
return String.format(Locale.US,
"SelectionEvent {absoluteStart=%d, absoluteEnd=%d, eventType=%d, entityType=%s, "
+ "widgetVersion=%s, packageName=%s, widgetType=%s, invocationMethod=%s, "
+ "resultId=%s, eventTime=%d, durationSinceSessionStart=%d, "
+ "userId=%d, resultId=%s, eventTime=%d, durationSinceSessionStart=%d, "
+ "durationSincePreviousEvent=%d, eventIndex=%d,"
+ "sessionId=%s, start=%d, end=%d, smartStart=%d, smartEnd=%d}",
mAbsoluteStart, mAbsoluteEnd, mEventType, mEntityType,
mWidgetVersion, mPackageName, mWidgetType, mInvocationMethod,
mResultId, mEventTime, mDurationSinceSessionStart,
mUserId, mResultId, mEventTime, mDurationSinceSessionStart,
mDurationSincePreviousEvent, mEventIndex,
mSessionId, mStart, mEnd, mSmartStart, mSmartEnd);
}

View File

@@ -18,6 +18,7 @@ package android.view.textclassifier;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.annotation.UserIdInt;
import android.annotation.WorkerThread;
import android.content.Context;
import android.os.Bundle;
@@ -50,6 +51,10 @@ public final class SystemTextClassifier implements TextClassifier {
private final TextClassificationConstants mSettings;
private final TextClassifier mFallback;
private final String mPackageName;
// NOTE: Always set this before sending a request to the manager service otherwise the manager
// service will throw a remote exception.
@UserIdInt
private final int mUserId;
private TextClassificationSessionId mSessionId;
public SystemTextClassifier(Context context, TextClassificationConstants settings)
@@ -60,6 +65,7 @@ public final class SystemTextClassifier implements TextClassifier {
mFallback = context.getSystemService(TextClassificationManager.class)
.getTextClassifier(TextClassifier.LOCAL);
mPackageName = Preconditions.checkNotNull(context.getOpPackageName());
mUserId = context.getUserId();
}
/**
@@ -72,6 +78,7 @@ public final class SystemTextClassifier implements TextClassifier {
Utils.checkMainThread();
try {
request.setCallingPackageName(mPackageName);
request.setUserId(mUserId);
final BlockingCallback<TextSelection> callback =
new BlockingCallback<>("textselection");
mManagerService.onSuggestSelection(mSessionId, request, callback);
@@ -95,6 +102,7 @@ public final class SystemTextClassifier implements TextClassifier {
Utils.checkMainThread();
try {
request.setCallingPackageName(mPackageName);
request.setUserId(mUserId);
final BlockingCallback<TextClassification> callback =
new BlockingCallback<>("textclassification");
mManagerService.onClassifyText(mSessionId, request, callback);
@@ -123,6 +131,7 @@ public final class SystemTextClassifier implements TextClassifier {
try {
request.setCallingPackageName(mPackageName);
request.setUserId(mUserId);
final BlockingCallback<TextLinks> callback =
new BlockingCallback<>("textlinks");
mManagerService.onGenerateLinks(mSessionId, request, callback);
@@ -142,6 +151,7 @@ public final class SystemTextClassifier implements TextClassifier {
Utils.checkMainThread();
try {
event.setUserId(mUserId);
mManagerService.onSelectionEvent(mSessionId, event);
} catch (RemoteException e) {
Log.e(LOG_TAG, "Error reporting selection event.", e);
@@ -154,6 +164,12 @@ public final class SystemTextClassifier implements TextClassifier {
Utils.checkMainThread();
try {
final TextClassificationContext tcContext = event.getEventContext() == null
? new TextClassificationContext.Builder(mPackageName, WIDGET_TYPE_UNKNOWN)
.build()
: event.getEventContext();
tcContext.setUserId(mUserId);
event.setEventContext(tcContext);
mManagerService.onTextClassifierEvent(mSessionId, event);
} catch (RemoteException e) {
Log.e(LOG_TAG, "Error reporting textclassifier event.", e);
@@ -167,6 +183,7 @@ public final class SystemTextClassifier implements TextClassifier {
try {
request.setCallingPackageName(mPackageName);
request.setUserId(mUserId);
final BlockingCallback<TextLanguage> callback =
new BlockingCallback<>("textlanguage");
mManagerService.onDetectLanguage(mSessionId, request, callback);
@@ -187,6 +204,7 @@ public final class SystemTextClassifier implements TextClassifier {
try {
request.setCallingPackageName(mPackageName);
request.setUserId(mUserId);
final BlockingCallback<ConversationActions> callback =
new BlockingCallback<>("conversation-actions");
mManagerService.onSuggestConversationActions(mSessionId, request, callback);
@@ -228,6 +246,7 @@ public final class SystemTextClassifier implements TextClassifier {
printWriter.printPair("mFallback", mFallback);
printWriter.printPair("mPackageName", mPackageName);
printWriter.printPair("mSessionId", mSessionId);
printWriter.printPair("mUserId", mUserId);
printWriter.decreaseIndent();
printWriter.println();
}
@@ -243,6 +262,7 @@ public final class SystemTextClassifier implements TextClassifier {
@NonNull TextClassificationSessionId sessionId) {
mSessionId = Preconditions.checkNotNull(sessionId);
try {
classificationContext.setUserId(mUserId);
mManagerService.onCreateTextClassificationSession(classificationContext, mSessionId);
} catch (RemoteException e) {
Log.e(LOG_TAG, "Error starting a new classification session.", e);

View File

@@ -21,6 +21,7 @@ import android.annotation.IntDef;
import android.annotation.IntRange;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.annotation.UserIdInt;
import android.app.PendingIntent;
import android.app.RemoteAction;
import android.content.Context;
@@ -35,6 +36,7 @@ import android.os.Bundle;
import android.os.LocaleList;
import android.os.Parcel;
import android.os.Parcelable;
import android.os.UserHandle;
import android.text.SpannedString;
import android.util.ArrayMap;
import android.view.View.OnClickListener;
@@ -551,6 +553,8 @@ public final class TextClassification implements Parcelable {
@Nullable private final ZonedDateTime mReferenceTime;
@NonNull private final Bundle mExtras;
@Nullable private String mCallingPackageName;
@UserIdInt
private int mUserId = UserHandle.USER_NULL;
private Request(
CharSequence text,
@@ -630,6 +634,24 @@ public final class TextClassification implements Parcelable {
return mCallingPackageName;
}
/**
* Sets the id of the user that sent this request.
* <p>
* Package-private for SystemTextClassifier's use.
*/
void setUserId(@UserIdInt int userId) {
mUserId = userId;
}
/**
* Returns the id of the user that sent this request.
* @hide
*/
@UserIdInt
public int getUserId() {
return mUserId;
}
/**
* Returns the extended data.
*
@@ -730,6 +752,7 @@ public final class TextClassification implements Parcelable {
dest.writeParcelable(mDefaultLocales, flags);
dest.writeString(mReferenceTime == null ? null : mReferenceTime.toString());
dest.writeString(mCallingPackageName);
dest.writeInt(mUserId);
dest.writeBundle(mExtras);
}
@@ -742,11 +765,13 @@ public final class TextClassification implements Parcelable {
final ZonedDateTime referenceTime = referenceTimeString == null
? null : ZonedDateTime.parse(referenceTimeString);
final String callingPackageName = in.readString();
final int userId = in.readInt();
final Bundle extras = in.readBundle();
final Request request = new Request(text, startIndex, endIndex,
defaultLocales, referenceTime, extras);
request.setCallingPackageName(callingPackageName);
request.setUserId(userId);
return request;
}

View File

@@ -18,8 +18,10 @@ package android.view.textclassifier;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.annotation.UserIdInt;
import android.os.Parcel;
import android.os.Parcelable;
import android.os.UserHandle;
import android.view.textclassifier.TextClassifier.WidgetType;
import com.android.internal.util.Preconditions;
@@ -35,6 +37,8 @@ public final class TextClassificationContext implements Parcelable {
private final String mPackageName;
private final String mWidgetType;
@Nullable private final String mWidgetVersion;
@UserIdInt
private int mUserId = UserHandle.USER_NULL;
private TextClassificationContext(
String packageName,
@@ -53,6 +57,24 @@ public final class TextClassificationContext implements Parcelable {
return mPackageName;
}
/**
* Sets the id of this context's user.
* <p>
* Package-private for SystemTextClassifier's use.
*/
void setUserId(@UserIdInt int userId) {
mUserId = userId;
}
/**
* Returns the id of this context's user.
* @hide
*/
@UserIdInt
public int getUserId() {
return mUserId;
}
/**
* Returns the widget type for this classification context.
*/
@@ -75,8 +97,8 @@ public final class TextClassificationContext implements Parcelable {
@Override
public String toString() {
return String.format(Locale.US, "TextClassificationContext{"
+ "packageName=%s, widgetType=%s, widgetVersion=%s}",
mPackageName, mWidgetType, mWidgetVersion);
+ "packageName=%s, widgetType=%s, widgetVersion=%s, userId=%d}",
mPackageName, mWidgetType, mWidgetVersion, mUserId);
}
/**
@@ -133,12 +155,14 @@ public final class TextClassificationContext implements Parcelable {
parcel.writeString(mPackageName);
parcel.writeString(mWidgetType);
parcel.writeString(mWidgetVersion);
parcel.writeInt(mUserId);
}
private TextClassificationContext(Parcel in) {
mPackageName = in.readString();
mWidgetType = in.readString();
mWidgetVersion = in.readString();
mUserId = in.readInt();
}
public static final @android.annotation.NonNull Parcelable.Creator<TextClassificationContext> CREATOR =

View File

@@ -139,7 +139,7 @@ public abstract class TextClassifierEvent implements Parcelable {
@Nullable
private final String[] mEntityTypes;
@Nullable
private final TextClassificationContext mEventContext;
private TextClassificationContext mEventContext;
@Nullable
private final String mResultId;
private final int mEventIndex;
@@ -288,6 +288,15 @@ public abstract class TextClassifierEvent implements Parcelable {
return mEventContext;
}
/**
* Sets the event context.
* <p>
* Package-private for SystemTextClassifier's use.
*/
void setEventContext(@Nullable TextClassificationContext eventContext) {
mEventContext = eventContext;
}
/**
* Returns the id of the text classifier result related to this event.
*/

View File

@@ -20,10 +20,12 @@ import android.annotation.FloatRange;
import android.annotation.IntRange;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.annotation.UserIdInt;
import android.icu.util.ULocale;
import android.os.Bundle;
import android.os.Parcel;
import android.os.Parcelable;
import android.os.UserHandle;
import android.util.ArrayMap;
import com.android.internal.annotations.VisibleForTesting;
@@ -226,6 +228,8 @@ public final class TextLanguage implements Parcelable {
private final CharSequence mText;
private final Bundle mExtra;
@Nullable private String mCallingPackageName;
@UserIdInt
private int mUserId = UserHandle.USER_NULL;
private Request(CharSequence text, Bundle bundle) {
mText = text;
@@ -259,6 +263,24 @@ public final class TextLanguage implements Parcelable {
return mCallingPackageName;
}
/**
* Sets the id of the user that sent this request.
* <p>
* Package-private for SystemTextClassifier's use.
*/
void setUserId(@UserIdInt int userId) {
mUserId = userId;
}
/**
* Returns the id of the user that sent this request.
* @hide
*/
@UserIdInt
public int getUserId() {
return mUserId;
}
/**
* Returns a bundle containing non-structured extra information about this request.
*
@@ -278,16 +300,19 @@ public final class TextLanguage implements Parcelable {
public void writeToParcel(Parcel dest, int flags) {
dest.writeCharSequence(mText);
dest.writeString(mCallingPackageName);
dest.writeInt(mUserId);
dest.writeBundle(mExtra);
}
private static Request readFromParcel(Parcel in) {
final CharSequence text = in.readCharSequence();
final String callingPackageName = in.readString();
final int userId = in.readInt();
final Bundle extra = in.readBundle();
final Request request = new Request(text, extra);
request.setCallingPackageName(callingPackageName);
request.setUserId(userId);
return request;
}

View File

@@ -20,11 +20,13 @@ import android.annotation.FloatRange;
import android.annotation.IntDef;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.annotation.UserIdInt;
import android.content.Context;
import android.os.Bundle;
import android.os.LocaleList;
import android.os.Parcel;
import android.os.Parcelable;
import android.os.UserHandle;
import android.text.Spannable;
import android.text.method.MovementMethod;
import android.text.style.ClickableSpan;
@@ -339,6 +341,8 @@ public final class TextLinks implements Parcelable {
private final boolean mLegacyFallback;
@Nullable private String mCallingPackageName;
private final Bundle mExtras;
@UserIdInt
private int mUserId = UserHandle.USER_NULL;
private Request(
CharSequence text,
@@ -409,6 +413,24 @@ public final class TextLinks implements Parcelable {
return mCallingPackageName;
}
/**
* Sets the id of the user that sent this request.
* <p>
* Package-private for SystemTextClassifier's use.
*/
void setUserId(@UserIdInt int userId) {
mUserId = userId;
}
/**
* Returns the id of the user that sent this request.
* @hide
*/
@UserIdInt
public int getUserId() {
return mUserId;
}
/**
* Returns the extended data.
*
@@ -509,6 +531,7 @@ public final class TextLinks implements Parcelable {
dest.writeParcelable(mDefaultLocales, flags);
dest.writeParcelable(mEntityConfig, flags);
dest.writeString(mCallingPackageName);
dest.writeInt(mUserId);
dest.writeBundle(mExtras);
}
@@ -517,11 +540,13 @@ public final class TextLinks implements Parcelable {
final LocaleList defaultLocales = in.readParcelable(null);
final EntityConfig entityConfig = in.readParcelable(null);
final String callingPackageName = in.readString();
final int userId = in.readInt();
final Bundle extras = in.readBundle();
final Request request = new Request(text, defaultLocales, entityConfig,
/* legacyFallback= */ true, extras);
request.setCallingPackageName(callingPackageName);
request.setUserId(userId);
return request;
}

View File

@@ -20,10 +20,12 @@ import android.annotation.FloatRange;
import android.annotation.IntRange;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.annotation.UserIdInt;
import android.os.Bundle;
import android.os.LocaleList;
import android.os.Parcel;
import android.os.Parcelable;
import android.os.UserHandle;
import android.text.SpannedString;
import android.util.ArrayMap;
import android.view.textclassifier.TextClassifier.EntityType;
@@ -211,6 +213,8 @@ public final class TextSelection implements Parcelable {
private final boolean mDarkLaunchAllowed;
private final Bundle mExtras;
@Nullable private String mCallingPackageName;
@UserIdInt
private int mUserId = UserHandle.USER_NULL;
private Request(
CharSequence text,
@@ -291,6 +295,24 @@ public final class TextSelection implements Parcelable {
return mCallingPackageName;
}
/**
* Sets the id of the user that sent this request.
* <p>
* Package-private for SystemTextClassifier's use.
*/
void setUserId(@UserIdInt int userId) {
mUserId = userId;
}
/**
* Returns the id of the user that sent this request.
* @hide
*/
@UserIdInt
public int getUserId() {
return mUserId;
}
/**
* Returns the extended data.
*
@@ -394,6 +416,7 @@ public final class TextSelection implements Parcelable {
dest.writeInt(mEndIndex);
dest.writeParcelable(mDefaultLocales, flags);
dest.writeString(mCallingPackageName);
dest.writeInt(mUserId);
dest.writeBundle(mExtras);
}
@@ -403,11 +426,13 @@ public final class TextSelection implements Parcelable {
final int endIndex = in.readInt();
final LocaleList defaultLocales = in.readParcelable(null);
final String callingPackageName = in.readString();
final int userId = in.readInt();
final Bundle extras = in.readBundle();
final Request request = new Request(text, startIndex, endIndex, defaultLocales,
/* darkLaunchAllowed= */ false, extras);
request.setCallingPackageName(callingPackageName);
request.setUserId(userId);
return request;
}

View File

@@ -11260,6 +11260,12 @@ public class TextView extends View implements ViewTreeObserver.OnPreDrawListener
return getServiceManagerForUser(getContext().getPackageName(), ClipboardManager.class);
}
@Nullable
final TextClassificationManager getTextClassificationManagerForUser() {
return getServiceManagerForUser(
getContext().getPackageName(), TextClassificationManager.class);
}
@Nullable
final <T> T getServiceManagerForUser(String packageName, Class<T> managerClazz) {
if (mTextOperationUser == null) {
@@ -12354,8 +12360,7 @@ public class TextView extends View implements ViewTreeObserver.OnPreDrawListener
@NonNull
public TextClassifier getTextClassifier() {
if (mTextClassifier == null) {
final TextClassificationManager tcm =
mContext.getSystemService(TextClassificationManager.class);
final TextClassificationManager tcm = getTextClassificationManagerForUser();
if (tcm != null) {
return tcm.getTextClassifier();
}
@@ -12371,8 +12376,7 @@ public class TextView extends View implements ViewTreeObserver.OnPreDrawListener
@NonNull
TextClassifier getTextClassificationSession() {
if (mTextClassificationSession == null || mTextClassificationSession.isDestroyed()) {
final TextClassificationManager tcm =
mContext.getSystemService(TextClassificationManager.class);
final TextClassificationManager tcm = getTextClassificationManagerForUser();
if (tcm != null) {
final String widgetType;
if (isTextEditable()) {

View File

@@ -30,6 +30,7 @@ import android.os.UserHandle;
import android.service.textclassifier.ITextClassifierCallback;
import android.service.textclassifier.ITextClassifierService;
import android.service.textclassifier.TextClassifierService;
import android.util.ArrayMap;
import android.util.Slog;
import android.util.SparseArray;
import android.view.textclassifier.ConversationActions;
@@ -54,6 +55,7 @@ import com.android.server.SystemService;
import java.io.FileDescriptor;
import java.io.PrintWriter;
import java.util.ArrayDeque;
import java.util.Map;
import java.util.Queue;
/**
@@ -119,6 +121,8 @@ public final class TextClassificationManagerService extends ITextClassifierServi
private final Object mLock;
@GuardedBy("mLock")
final SparseArray<UserState> mUserStates = new SparseArray<>();
@GuardedBy("mLock")
private final Map<TextClassificationSessionId, Integer> mSessionUserIds = new ArrayMap<>();
private TextClassificationManagerService(Context context) {
mContext = Preconditions.checkNotNull(context);
@@ -127,15 +131,16 @@ public final class TextClassificationManagerService extends ITextClassifierServi
@Override
public void onSuggestSelection(
TextClassificationSessionId sessionId,
@Nullable TextClassificationSessionId sessionId,
TextSelection.Request request, ITextClassifierCallback callback)
throws RemoteException {
Preconditions.checkNotNull(request);
Preconditions.checkNotNull(callback);
validateInput(mContext, request.getCallingPackageName());
final int userId = request.getUserId();
validateInput(mContext, request.getCallingPackageName(), userId);
synchronized (mLock) {
UserState userState = getCallingUserStateLocked();
UserState userState = getUserStateLocked(userId);
if (!userState.bindLocked()) {
callback.onFailure();
} else if (userState.isBoundLocked()) {
@@ -150,15 +155,16 @@ public final class TextClassificationManagerService extends ITextClassifierServi
@Override
public void onClassifyText(
TextClassificationSessionId sessionId,
@Nullable TextClassificationSessionId sessionId,
TextClassification.Request request, ITextClassifierCallback callback)
throws RemoteException {
Preconditions.checkNotNull(request);
Preconditions.checkNotNull(callback);
validateInput(mContext, request.getCallingPackageName());
final int userId = request.getUserId();
validateInput(mContext, request.getCallingPackageName(), userId);
synchronized (mLock) {
UserState userState = getCallingUserStateLocked();
UserState userState = getUserStateLocked(userId);
if (!userState.bindLocked()) {
callback.onFailure();
} else if (userState.isBoundLocked()) {
@@ -173,15 +179,16 @@ public final class TextClassificationManagerService extends ITextClassifierServi
@Override
public void onGenerateLinks(
TextClassificationSessionId sessionId,
@Nullable TextClassificationSessionId sessionId,
TextLinks.Request request, ITextClassifierCallback callback)
throws RemoteException {
Preconditions.checkNotNull(request);
Preconditions.checkNotNull(callback);
validateInput(mContext, request.getCallingPackageName());
final int userId = request.getUserId();
validateInput(mContext, request.getCallingPackageName(), userId);
synchronized (mLock) {
UserState userState = getCallingUserStateLocked();
UserState userState = getUserStateLocked(userId);
if (!userState.bindLocked()) {
callback.onFailure();
} else if (userState.isBoundLocked()) {
@@ -196,12 +203,14 @@ public final class TextClassificationManagerService extends ITextClassifierServi
@Override
public void onSelectionEvent(
TextClassificationSessionId sessionId, SelectionEvent event) throws RemoteException {
@Nullable TextClassificationSessionId sessionId, SelectionEvent event)
throws RemoteException {
Preconditions.checkNotNull(event);
validateInput(mContext, event.getPackageName());
final int userId = event.getUserId();
validateInput(mContext, event.getPackageName(), userId);
synchronized (mLock) {
UserState userState = getCallingUserStateLocked();
UserState userState = getUserStateLocked(userId);
if (userState.isBoundLocked()) {
userState.mService.onSelectionEvent(sessionId, event);
} else {
@@ -213,16 +222,19 @@ public final class TextClassificationManagerService extends ITextClassifierServi
}
@Override
public void onTextClassifierEvent(
TextClassificationSessionId sessionId,
@Nullable TextClassificationSessionId sessionId,
TextClassifierEvent event) throws RemoteException {
Preconditions.checkNotNull(event);
final String packageName = event.getEventContext() == null
? null
: event.getEventContext().getPackageName();
validateInput(mContext, packageName);
final int userId = event.getEventContext() == null
? UserHandle.getCallingUserId()
: event.getEventContext().getUserId();
validateInput(mContext, packageName, userId);
synchronized (mLock) {
UserState userState = getCallingUserStateLocked();
UserState userState = getUserStateLocked(userId);
if (userState.isBoundLocked()) {
userState.mService.onTextClassifierEvent(sessionId, event);
} else {
@@ -235,15 +247,16 @@ public final class TextClassificationManagerService extends ITextClassifierServi
@Override
public void onDetectLanguage(
TextClassificationSessionId sessionId,
@Nullable TextClassificationSessionId sessionId,
TextLanguage.Request request,
ITextClassifierCallback callback) throws RemoteException {
Preconditions.checkNotNull(request);
Preconditions.checkNotNull(callback);
validateInput(mContext, request.getCallingPackageName());
final int userId = request.getUserId();
validateInput(mContext, request.getCallingPackageName(), userId);
synchronized (mLock) {
UserState userState = getCallingUserStateLocked();
UserState userState = getUserStateLocked(userId);
if (!userState.bindLocked()) {
callback.onFailure();
} else if (userState.isBoundLocked()) {
@@ -258,15 +271,16 @@ public final class TextClassificationManagerService extends ITextClassifierServi
@Override
public void onSuggestConversationActions(
TextClassificationSessionId sessionId,
@Nullable TextClassificationSessionId sessionId,
ConversationActions.Request request,
ITextClassifierCallback callback) throws RemoteException {
Preconditions.checkNotNull(request);
Preconditions.checkNotNull(callback);
validateInput(mContext, request.getCallingPackageName());
final int userId = request.getUserId();
validateInput(mContext, request.getCallingPackageName(), userId);
synchronized (mLock) {
UserState userState = getCallingUserStateLocked();
UserState userState = getUserStateLocked(userId);
if (!userState.bindLocked()) {
callback.onFailure();
} else if (userState.isBoundLocked()) {
@@ -285,13 +299,15 @@ public final class TextClassificationManagerService extends ITextClassifierServi
throws RemoteException {
Preconditions.checkNotNull(sessionId);
Preconditions.checkNotNull(classificationContext);
validateInput(mContext, classificationContext.getPackageName());
final int userId = classificationContext.getUserId();
validateInput(mContext, classificationContext.getPackageName(), userId);
synchronized (mLock) {
UserState userState = getCallingUserStateLocked();
UserState userState = getUserStateLocked(userId);
if (userState.isBoundLocked()) {
userState.mService.onCreateTextClassificationSession(
classificationContext, sessionId);
mSessionUserIds.put(sessionId, userId);
} else {
userState.mPendingRequests.add(new PendingRequest(
() -> onCreateTextClassificationSession(classificationContext, sessionId),
@@ -306,9 +322,15 @@ public final class TextClassificationManagerService extends ITextClassifierServi
Preconditions.checkNotNull(sessionId);
synchronized (mLock) {
UserState userState = getCallingUserStateLocked();
final int userId = mSessionUserIds.containsKey(sessionId)
? mSessionUserIds.get(sessionId)
: UserHandle.getCallingUserId();
validateInput(mContext, null /* packageName */, userId);
UserState userState = getUserStateLocked(userId);
if (userState.isBoundLocked()) {
userState.mService.onDestroyTextClassificationSession(sessionId);
mSessionUserIds.remove(sessionId);
} else {
userState.mPendingRequests.add(new PendingRequest(
() -> onDestroyTextClassificationSession(sessionId),
@@ -317,11 +339,6 @@ public final class TextClassificationManagerService extends ITextClassifierServi
}
}
@GuardedBy("mLock")
private UserState getCallingUserStateLocked() {
return getUserStateLocked(UserHandle.getCallingUserId());
}
@GuardedBy("mLock")
private UserState getUserStateLocked(int userId) {
UserState result = mUserStates.get(userId);
@@ -356,6 +373,7 @@ public final class TextClassificationManagerService extends ITextClassifierServi
pw.decreaseIndent();
}
}
pw.println("Number of active sessions: " + mSessionUserIds.size());
}
}
@@ -420,20 +438,32 @@ public final class TextClassificationManagerService extends ITextClassifierServi
e -> Slog.d(LOG_TAG, "Error " + opDesc + ": " + e.getMessage()));
}
private static void validateInput(Context context, @Nullable String packageName)
private static void validateInput(
Context context, @Nullable String packageName, @UserIdInt int userId)
throws RemoteException {
if (packageName == null) return;
try {
final int packageUid = context.getPackageManager()
.getPackageUidAsUser(packageName, UserHandle.getCallingUserId());
final int callingUid = Binder.getCallingUid();
Preconditions.checkArgument(callingUid == packageUid
// Trust the system process:
|| callingUid == android.os.Process.SYSTEM_UID);
if (packageName != null) {
final int packageUid = context.getPackageManager()
.getPackageUidAsUser(packageName, UserHandle.getCallingUserId());
final int callingUid = Binder.getCallingUid();
Preconditions.checkArgument(callingUid == packageUid
// Trust the system process:
|| callingUid == android.os.Process.SYSTEM_UID,
"Invalid package name. Package=" + packageName
+ ", CallingUid=" + callingUid);
}
Preconditions.checkArgument(userId != UserHandle.USER_NULL, "Null userId");
final int callingUserId = UserHandle.getCallingUserId();
if (callingUserId != userId) {
context.enforceCallingOrSelfPermission(
android.Manifest.permission.INTERACT_ACROSS_USERS_FULL,
"Invalid userId. UserId=" + userId + ", CallingUserId=" + callingUserId);
}
} catch (Exception e) {
throw new RemoteException(
String.format("Invalid package: name=%s, error=%s", packageName, e));
throw new RemoteException("Invalid request: " + e.getMessage(), e,
/* enableSuppression */ true, /* writableStackTrace */ true);
}
}