diff --git a/core/java/com/android/internal/app/AbstractResolverComparator.java b/core/java/com/android/internal/app/AbstractResolverComparator.java index e091aac04c601..b7276a0450cc6 100644 --- a/core/java/com/android/internal/app/AbstractResolverComparator.java +++ b/core/java/com/android/internal/app/AbstractResolverComparator.java @@ -1,3 +1,19 @@ +/* + * Copyright 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package com.android.internal.app; import android.app.usage.UsageStatsManager; @@ -20,7 +36,7 @@ abstract class AbstractResolverComparator implements Comparator mTargetRanks = new HashMap<>(); + private final UserHandle mUser; + + AppPredictionServiceResolverComparator( + Context context, Intent intent, AppPredictor appPredictor, UserHandle user) { + super(context, intent); + mContext = context; + mAppPredictor = appPredictor; + mUser = user; + } + + @Override + int compare(ResolveInfo lhs, ResolveInfo rhs) { + Integer lhsRank = mTargetRanks.get(new ComponentName(lhs.activityInfo.packageName, + lhs.activityInfo.name)); + Integer rhsRank = mTargetRanks.get(new ComponentName(rhs.activityInfo.packageName, + rhs.activityInfo.name)); + if (lhsRank == null && rhsRank == null) { + return 0; + } else if (lhsRank == null) { + return -1; + } else if (rhsRank == null) { + return 1; + } + return lhsRank - rhsRank; + } + + @Override + void compute(List targets) { + List appTargets = new ArrayList<>(); + for (ResolvedComponentInfo target : targets) { + appTargets.add(new AppTarget.Builder(new AppTargetId(target.name.flattenToString())) + .setTarget(target.name.getPackageName(), mUser) + .setClassName(target.name.getClassName()).build()); + } + mAppPredictor.sortTargets(appTargets, mContext.getMainExecutor(), + sortedAppTargets -> { + for (int i = 0; i < sortedAppTargets.size(); i++) { + mTargetRanks.put(new ComponentName(sortedAppTargets.get(i).getPackageName(), + sortedAppTargets.get(i).getClassName()), i); + } + afterCompute(); + }); + } + + @Override + float getScore(ComponentName name) { + Integer rank = mTargetRanks.get(name); + if (rank == null) { + Log.w(TAG, "Score requested for unknown component."); + return 0f; + } + int consecutiveSumOfRanks = (mTargetRanks.size() - 1) * (mTargetRanks.size()) / 2; + return 1.0f - (((float) rank) / consecutiveSumOfRanks); + } + + @Override + void updateModel(ComponentName componentName) { + mAppPredictor.notifyAppTargetEvent( + new AppTargetEvent.Builder( + new AppTarget.Builder( + new AppTargetId(componentName.toString()), + componentName.getPackageName(), mUser) + .setClassName(componentName.getClassName()).build(), + ACTION_LAUNCH).build()); + } + + @Override + void destroy() { + // Do nothing. App Predictor destruction is handled by caller. + } +} diff --git a/core/java/com/android/internal/app/ChooserActivity.java b/core/java/com/android/internal/app/ChooserActivity.java index 54338bf6a1763..59e867ff9dd6e 100644 --- a/core/java/com/android/internal/app/ChooserActivity.java +++ b/core/java/com/android/internal/app/ChooserActivity.java @@ -150,6 +150,7 @@ public class ChooserActivity extends ResolverActivity { */ // TODO(b/123089490): Replace with system flag private static final boolean USE_PREDICTION_MANAGER_FOR_DIRECT_TARGETS = false; + private static final boolean USE_PREDICTION_MANAGER_FOR_SHARE_ACTIVITIES = false; // TODO(b/123088566) Share these in a better way. private static final String APP_PREDICTION_SHARE_UI_SURFACE = "share"; public static final String LAUNCH_LOCATON_DIRECT_SHARE = "direct_share"; @@ -1387,6 +1388,15 @@ public class ChooserActivity extends ResolverActivity { return USE_PREDICTION_MANAGER_FOR_DIRECT_TARGETS ? getAppPredictor() : null; } + /** + * This will return an app predictor if it is enabled for share activity sorting + * and if one exists. Otherwise, it returns null. + */ + @Nullable + private AppPredictor getAppPredictorForShareActivitesIfEnabled() { + return USE_PREDICTION_MANAGER_FOR_SHARE_ACTIVITIES ? getAppPredictor() : null; + } + void onRefinementResult(TargetInfo selectedTarget, Intent matchingIntent) { if (mRefinementResultReceiver != null) { mRefinementResultReceiver.destroy(); @@ -1491,8 +1501,10 @@ public class ChooserActivity extends ResolverActivity { PackageManager pm, Intent targetIntent, String referrerPackageName, - int launchedFromUid) { - super(context, pm, targetIntent, referrerPackageName, launchedFromUid); + int launchedFromUid, + AbstractResolverComparator resolverComparator) { + super(context, pm, targetIntent, referrerPackageName, launchedFromUid, + resolverComparator); } @Override @@ -1520,13 +1532,24 @@ public class ChooserActivity extends ResolverActivity { @VisibleForTesting protected ResolverListController createListController() { + AppPredictor appPredictor = getAppPredictorForShareActivitesIfEnabled(); + AbstractResolverComparator resolverComparator; + if (appPredictor != null) { + resolverComparator = new AppPredictionServiceResolverComparator(this, getTargetIntent(), + appPredictor, getUser()); + } else { + resolverComparator = + new ResolverRankerServiceResolverComparator(this, getTargetIntent(), + getReferrerPackageName(), null); + } + return new ChooserListController( this, mPm, getTargetIntent(), getReferrerPackageName(), - mLaunchedFromUid - ); + mLaunchedFromUid, + resolverComparator); } @VisibleForTesting diff --git a/core/java/com/android/internal/app/ResolverListController.java b/core/java/com/android/internal/app/ResolverListController.java index a3cfa8786d594..5f92cddbaa38c 100644 --- a/core/java/com/android/internal/app/ResolverListController.java +++ b/core/java/com/android/internal/app/ResolverListController.java @@ -63,14 +63,24 @@ public class ResolverListController { Intent targetIntent, String referrerPackage, int launchedFromUid) { + this(context, pm, targetIntent, referrerPackage, launchedFromUid, + new ResolverRankerServiceResolverComparator( + context, targetIntent, referrerPackage, null)); + } + + public ResolverListController( + Context context, + PackageManager pm, + Intent targetIntent, + String referrerPackage, + int launchedFromUid, + AbstractResolverComparator resolverComparator) { mContext = context; mpm = pm; mLaunchedFromUid = launchedFromUid; mTargetIntent = targetIntent; mReferrerPackage = referrerPackage; - mResolverComparator = - new ResolverRankerServiceResolverComparator( - mContext, mTargetIntent, mReferrerPackage, null); + mResolverComparator = resolverComparator; } @VisibleForTesting diff --git a/core/java/com/android/internal/app/ResolverRankerServiceResolverComparator.java b/core/java/com/android/internal/app/ResolverRankerServiceResolverComparator.java index 9bf4f01bab06c..726b186d8edb0 100644 --- a/core/java/com/android/internal/app/ResolverRankerServiceResolverComparator.java +++ b/core/java/com/android/internal/app/ResolverRankerServiceResolverComparator.java @@ -126,7 +126,7 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator Log.e(TAG, "Receiving null prediction results."); } mHandler.removeMessages(RESOLVER_RANKER_RESULT_TIMEOUT); - mAfterCompute.afterCompute(); + afterCompute(); } break; @@ -135,7 +135,7 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator Log.d(TAG, "RESOLVER_RANKER_RESULT_TIMEOUT; unbinding services"); } mHandler.removeMessages(RESOLVER_RANKER_SERVICE_RESULT); - mAfterCompute.afterCompute(); + afterCompute(); break; default: @@ -149,7 +149,6 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator super(context, intent); mCollator = Collator.getInstance(context.getResources().getConfiguration().locale); mReferrerPackage = referrerPackage; - mAfterCompute = afterCompute; mContext = context; mCurrentTime = System.currentTimeMillis(); @@ -157,6 +156,7 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator mStats = mUsm.queryAndAggregateUsageStats(mSinceTime, mCurrentTime); mAction = intent.getAction(); mRankerServiceName = new ComponentName(mContext, this.getClass()); + setCallBack(afterCompute); } // compute features for each target according to usage stats of targets. @@ -328,9 +328,7 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator mContext.unbindService(mConnection); mConnection.destroy(); } - if (mAfterCompute != null) { - mAfterCompute.afterCompute(); - } + afterCompute(); if (DEBUG) { Log.d(TAG, "Unbinded Resolver Ranker."); } @@ -513,9 +511,7 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator Log.e(TAG, "Error in Predict: " + e); } } - if (mAfterCompute != null) { - mAfterCompute.afterCompute(); - } + afterCompute(); } // adds select prob as the default values, according to a pre-trained Logistic Regression model.