Merge "Add AppPredictionServiceResolverComparator" into qt-dev

am: ba3b157e32

Change-Id: I385cc811982f84186748b8565269fc40b4d6779a
This commit is contained in:
George Hodulik
2019-04-18 23:17:53 -07:00
committed by android-build-merger
5 changed files with 188 additions and 17 deletions

View File

@@ -1,3 +1,19 @@
/*
* Copyright 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.internal.app;
import android.app.usage.UsageStatsManager;
@@ -20,7 +36,7 @@ abstract class AbstractResolverComparator implements Comparator<ResolvedComponen
private static final int NUM_OF_TOP_ANNOTATIONS_TO_USE = 3;
protected AfterCompute mAfterCompute;
private AfterCompute mAfterCompute;
protected final PackageManager mPm;
protected final UsageStatsManager mUsm;
protected String[] mAnnotations;
@@ -72,6 +88,13 @@ abstract class AbstractResolverComparator implements Comparator<ResolvedComponen
mAfterCompute = afterCompute;
}
protected final void afterCompute() {
final AfterCompute afterCompute = mAfterCompute;
if (afterCompute != null) {
afterCompute.afterCompute();
}
}
@Override
public final int compare(ResolvedComponentInfo lhsp, ResolvedComponentInfo rhsp) {
final ResolveInfo lhs = lhsp.getResolveInfoAt(0);

View File

@@ -0,0 +1,119 @@
/*
* Copyright 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.internal.app;
import static android.app.prediction.AppTargetEvent.ACTION_LAUNCH;
import android.app.prediction.AppPredictor;
import android.app.prediction.AppTarget;
import android.app.prediction.AppTargetEvent;
import android.app.prediction.AppTargetId;
import android.content.ComponentName;
import android.content.Context;
import android.content.Intent;
import android.content.pm.ResolveInfo;
import android.os.UserHandle;
import android.view.textclassifier.Log;
import com.android.internal.app.ResolverActivity.ResolvedComponentInfo;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* Uses an {@link AppPredictor} to sort Resolver targets.
*/
class AppPredictionServiceResolverComparator extends AbstractResolverComparator {
private static final String TAG = "APSResolverComparator";
private final AppPredictor mAppPredictor;
private final Context mContext;
private final Map<ComponentName, Integer> mTargetRanks = new HashMap<>();
private final UserHandle mUser;
AppPredictionServiceResolverComparator(
Context context, Intent intent, AppPredictor appPredictor, UserHandle user) {
super(context, intent);
mContext = context;
mAppPredictor = appPredictor;
mUser = user;
}
@Override
int compare(ResolveInfo lhs, ResolveInfo rhs) {
Integer lhsRank = mTargetRanks.get(new ComponentName(lhs.activityInfo.packageName,
lhs.activityInfo.name));
Integer rhsRank = mTargetRanks.get(new ComponentName(rhs.activityInfo.packageName,
rhs.activityInfo.name));
if (lhsRank == null && rhsRank == null) {
return 0;
} else if (lhsRank == null) {
return -1;
} else if (rhsRank == null) {
return 1;
}
return lhsRank - rhsRank;
}
@Override
void compute(List<ResolvedComponentInfo> targets) {
List<AppTarget> appTargets = new ArrayList<>();
for (ResolvedComponentInfo target : targets) {
appTargets.add(new AppTarget.Builder(new AppTargetId(target.name.flattenToString()))
.setTarget(target.name.getPackageName(), mUser)
.setClassName(target.name.getClassName()).build());
}
mAppPredictor.sortTargets(appTargets, mContext.getMainExecutor(),
sortedAppTargets -> {
for (int i = 0; i < sortedAppTargets.size(); i++) {
mTargetRanks.put(new ComponentName(sortedAppTargets.get(i).getPackageName(),
sortedAppTargets.get(i).getClassName()), i);
}
afterCompute();
});
}
@Override
float getScore(ComponentName name) {
Integer rank = mTargetRanks.get(name);
if (rank == null) {
Log.w(TAG, "Score requested for unknown component.");
return 0f;
}
int consecutiveSumOfRanks = (mTargetRanks.size() - 1) * (mTargetRanks.size()) / 2;
return 1.0f - (((float) rank) / consecutiveSumOfRanks);
}
@Override
void updateModel(ComponentName componentName) {
mAppPredictor.notifyAppTargetEvent(
new AppTargetEvent.Builder(
new AppTarget.Builder(
new AppTargetId(componentName.toString()),
componentName.getPackageName(), mUser)
.setClassName(componentName.getClassName()).build(),
ACTION_LAUNCH).build());
}
@Override
void destroy() {
// Do nothing. App Predictor destruction is handled by caller.
}
}

View File

@@ -150,6 +150,7 @@ public class ChooserActivity extends ResolverActivity {
*/
// TODO(b/123089490): Replace with system flag
private static final boolean USE_PREDICTION_MANAGER_FOR_DIRECT_TARGETS = false;
private static final boolean USE_PREDICTION_MANAGER_FOR_SHARE_ACTIVITIES = false;
// TODO(b/123088566) Share these in a better way.
private static final String APP_PREDICTION_SHARE_UI_SURFACE = "share";
public static final String LAUNCH_LOCATON_DIRECT_SHARE = "direct_share";
@@ -1387,6 +1388,15 @@ public class ChooserActivity extends ResolverActivity {
return USE_PREDICTION_MANAGER_FOR_DIRECT_TARGETS ? getAppPredictor() : null;
}
/**
* This will return an app predictor if it is enabled for share activity sorting
* and if one exists. Otherwise, it returns null.
*/
@Nullable
private AppPredictor getAppPredictorForShareActivitesIfEnabled() {
return USE_PREDICTION_MANAGER_FOR_SHARE_ACTIVITIES ? getAppPredictor() : null;
}
void onRefinementResult(TargetInfo selectedTarget, Intent matchingIntent) {
if (mRefinementResultReceiver != null) {
mRefinementResultReceiver.destroy();
@@ -1491,8 +1501,10 @@ public class ChooserActivity extends ResolverActivity {
PackageManager pm,
Intent targetIntent,
String referrerPackageName,
int launchedFromUid) {
super(context, pm, targetIntent, referrerPackageName, launchedFromUid);
int launchedFromUid,
AbstractResolverComparator resolverComparator) {
super(context, pm, targetIntent, referrerPackageName, launchedFromUid,
resolverComparator);
}
@Override
@@ -1520,13 +1532,24 @@ public class ChooserActivity extends ResolverActivity {
@VisibleForTesting
protected ResolverListController createListController() {
AppPredictor appPredictor = getAppPredictorForShareActivitesIfEnabled();
AbstractResolverComparator resolverComparator;
if (appPredictor != null) {
resolverComparator = new AppPredictionServiceResolverComparator(this, getTargetIntent(),
appPredictor, getUser());
} else {
resolverComparator =
new ResolverRankerServiceResolverComparator(this, getTargetIntent(),
getReferrerPackageName(), null);
}
return new ChooserListController(
this,
mPm,
getTargetIntent(),
getReferrerPackageName(),
mLaunchedFromUid
);
mLaunchedFromUid,
resolverComparator);
}
@VisibleForTesting

View File

@@ -63,14 +63,24 @@ public class ResolverListController {
Intent targetIntent,
String referrerPackage,
int launchedFromUid) {
this(context, pm, targetIntent, referrerPackage, launchedFromUid,
new ResolverRankerServiceResolverComparator(
context, targetIntent, referrerPackage, null));
}
public ResolverListController(
Context context,
PackageManager pm,
Intent targetIntent,
String referrerPackage,
int launchedFromUid,
AbstractResolverComparator resolverComparator) {
mContext = context;
mpm = pm;
mLaunchedFromUid = launchedFromUid;
mTargetIntent = targetIntent;
mReferrerPackage = referrerPackage;
mResolverComparator =
new ResolverRankerServiceResolverComparator(
mContext, mTargetIntent, mReferrerPackage, null);
mResolverComparator = resolverComparator;
}
@VisibleForTesting

View File

@@ -126,7 +126,7 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator
Log.e(TAG, "Receiving null prediction results.");
}
mHandler.removeMessages(RESOLVER_RANKER_RESULT_TIMEOUT);
mAfterCompute.afterCompute();
afterCompute();
}
break;
@@ -135,7 +135,7 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator
Log.d(TAG, "RESOLVER_RANKER_RESULT_TIMEOUT; unbinding services");
}
mHandler.removeMessages(RESOLVER_RANKER_SERVICE_RESULT);
mAfterCompute.afterCompute();
afterCompute();
break;
default:
@@ -149,7 +149,6 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator
super(context, intent);
mCollator = Collator.getInstance(context.getResources().getConfiguration().locale);
mReferrerPackage = referrerPackage;
mAfterCompute = afterCompute;
mContext = context;
mCurrentTime = System.currentTimeMillis();
@@ -157,6 +156,7 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator
mStats = mUsm.queryAndAggregateUsageStats(mSinceTime, mCurrentTime);
mAction = intent.getAction();
mRankerServiceName = new ComponentName(mContext, this.getClass());
setCallBack(afterCompute);
}
// compute features for each target according to usage stats of targets.
@@ -328,9 +328,7 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator
mContext.unbindService(mConnection);
mConnection.destroy();
}
if (mAfterCompute != null) {
mAfterCompute.afterCompute();
}
afterCompute();
if (DEBUG) {
Log.d(TAG, "Unbinded Resolver Ranker.");
}
@@ -513,9 +511,7 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator
Log.e(TAG, "Error in Predict: " + e);
}
}
if (mAfterCompute != null) {
mAfterCompute.afterCompute();
}
afterCompute();
}
// adds select prob as the default values, according to a pre-trained Logistic Regression model.