From aaab02a91fcb33388d958a6402b2e77481a0d5b2 Mon Sep 17 00:00:00 2001 From: Ed Heyl Date: Tue, 17 Feb 2015 14:51:48 -0800 Subject: [PATCH] Revert "Add BLAS intrinsic." This reverts commit 25207df658d6a8a3e885c7017fcc25702363583c. --- .../android/renderscript/RenderScript.java | 49 - .../renderscript/ScriptIntrinsicBLAS.java | 1489 ----------------- rs/jni/android_renderscript_RenderScript.cpp | 142 -- 3 files changed, 1680 deletions(-) delete mode 100644 rs/java/android/renderscript/ScriptIntrinsicBLAS.java diff --git a/rs/java/android/renderscript/RenderScript.java b/rs/java/android/renderscript/RenderScript.java index e7487aa4edd22..5e150e98572b0 100644 --- a/rs/java/android/renderscript/RenderScript.java +++ b/rs/java/android/renderscript/RenderScript.java @@ -887,55 +887,6 @@ public class RenderScript { return rsnPathCreate(mContext, prim, isStatic, vtx, loop, q); } - native void rsnScriptIntrinsicBLAS_Single(long con, long id, int func, int TransA, - int TransB, int Side, int Uplo, int Diag, int M, int N, int K, - float alpha, long A, long B, float beta, long C, int incX, int incY, - int KL, int KU); - synchronized void nScriptIntrinsicBLAS_Single(long id, int func, int TransA, - int TransB, int Side, int Uplo, int Diag, int M, int N, int K, - float alpha, long A, long B, float beta, long C, int incX, int incY, - int KL, int KU) { - validate(); - rsnScriptIntrinsicBLAS_Single(mContext, id, func, TransA, TransB, Side, Uplo, Diag, M, N, K, alpha, A, B, beta, C, incX, incY, KL, KU); - } - - native void rsnScriptIntrinsicBLAS_Double(long con, long id, int func, int TransA, - int TransB, int Side, int Uplo, int Diag, int M, int N, int K, - double alpha, long A, long B, double beta, long C, int incX, int incY, - int KL, int KU); - synchronized void nScriptIntrinsicBLAS_Double(long id, int func, int TransA, - int TransB, int Side, int Uplo, int Diag, int M, int N, int K, - double alpha, long A, long B, double beta, long C, int incX, int incY, - int KL, int KU) { - validate(); - rsnScriptIntrinsicBLAS_Double(mContext, id, func, TransA, TransB, Side, Uplo, Diag, M, N, K, alpha, A, B, beta, C, incX, incY, KL, KU); - } - - native void rsnScriptIntrinsicBLAS_Complex(long con, long id, int func, int TransA, - int TransB, int Side, int Uplo, int Diag, int M, int N, int K, - float alphaX, float alphaY, long A, long B, float betaX, float betaY, long C, int incX, int incY, - int KL, int KU); - synchronized void nScriptIntrinsicBLAS_Complex(long id, int func, int TransA, - int TransB, int Side, int Uplo, int Diag, int M, int N, int K, - float alphaX, float alphaY, long A, long B, float betaX, float betaY, long C, int incX, int incY, - int KL, int KU) { - validate(); - rsnScriptIntrinsicBLAS_Complex(mContext, id, func, TransA, TransB, Side, Uplo, Diag, M, N, K, alphaX, alphaY, A, B, betaX, betaY, C, incX, incY, KL, KU); - } - - native void rsnScriptIntrinsicBLAS_Z(long con, long id, int func, int TransA, - int TransB, int Side, int Uplo, int Diag, int M, int N, int K, - double alphaX, double alphaY, long A, long B, double betaX, double betaY, long C, int incX, int incY, - int KL, int KU); - synchronized void nScriptIntrinsicBLAS_Z(long id, int func, int TransA, - int TransB, int Side, int Uplo, int Diag, int M, int N, int K, - double alphaX, double alphaY, long A, long B, double betaX, double betaY, long C, int incX, int incY, - int KL, int KU) { - validate(); - rsnScriptIntrinsicBLAS_Z(mContext, id, func, TransA, TransB, Side, Uplo, Diag, M, N, K, alphaX, alphaY, A, B, betaX, betaY, C, incX, incY, KL, KU); - } - - long mDev; long mContext; @SuppressWarnings({"FieldCanBeLocal"}) diff --git a/rs/java/android/renderscript/ScriptIntrinsicBLAS.java b/rs/java/android/renderscript/ScriptIntrinsicBLAS.java deleted file mode 100644 index 90d2300e553f9..0000000000000 --- a/rs/java/android/renderscript/ScriptIntrinsicBLAS.java +++ /dev/null @@ -1,1489 +0,0 @@ -/* - * Copyright (C) 2015 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package android.renderscript; - -import android.annotation.IntDef; -import java.lang.annotation.Retention; -import java.lang.annotation.RetentionPolicy; - -/** - * - * BLAS - * - * @hide - **/ -public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { - private Allocation mLUT; - - private ScriptIntrinsicBLAS(long id, RenderScript rs) { - super(id, rs); - } - - private static final int RsBlas_sdsdot = 1; - private static final int RsBlas_dsdot = 2; - private static final int RsBlas_sdot = 3; - private static final int RsBlas_ddot = 4; - private static final int RsBlas_cdotu_sub = 5; - private static final int RsBlas_cdotc_sub = 6; - private static final int RsBlas_zdotu_sub = 7; - private static final int RsBlas_zdotc_sub = 8; - private static final int RsBlas_snrm2 = 9; - private static final int RsBlas_sasum = 10; - private static final int RsBlas_dnrm2 = 11; - private static final int RsBlas_dasum = 12; - private static final int RsBlas_scnrm2 = 13; - private static final int RsBlas_scasum = 14; - private static final int RsBlas_dznrm2 = 15; - private static final int RsBlas_dzasum = 16; - private static final int RsBlas_isamax = 17; - private static final int RsBlas_idamax = 18; - private static final int RsBlas_icamax = 19; - private static final int RsBlas_izamax = 20; - private static final int RsBlas_sswap = 21; - private static final int RsBlas_scopy = 22; - private static final int RsBlas_saxpy = 23; - private static final int RsBlas_dswap = 24; - private static final int RsBlas_dcopy = 25; - private static final int RsBlas_daxpy = 26; - private static final int RsBlas_cswap = 27; - private static final int RsBlas_ccopy = 28; - private static final int RsBlas_caxpy = 29; - private static final int RsBlas_zswap = 30; - private static final int RsBlas_zcopy = 31; - private static final int RsBlas_zaxpy = 32; - private static final int RsBlas_srotg = 33; - private static final int RsBlas_srotmg = 34; - private static final int RsBlas_srot = 35; - private static final int RsBlas_srotm = 36; - private static final int RsBlas_drotg = 37; - private static final int RsBlas_drotmg = 38; - private static final int RsBlas_drot = 39; - private static final int RsBlas_drotm = 40; - private static final int RsBlas_sscal = 41; - private static final int RsBlas_dscal = 42; - private static final int RsBlas_cscal = 43; - private static final int RsBlas_zscal = 44; - private static final int RsBlas_csscal = 45; - private static final int RsBlas_zdscal = 46; - private static final int RsBlas_sgemv = 47; - private static final int RsBlas_sgbmv = 48; - private static final int RsBlas_strmv = 49; - private static final int RsBlas_stbmv = 50; - private static final int RsBlas_stpmv = 51; - private static final int RsBlas_strsv = 52; - private static final int RsBlas_stbsv = 53; - private static final int RsBlas_stpsv = 54; - private static final int RsBlas_dgemv = 55; - private static final int RsBlas_dgbmv = 56; - private static final int RsBlas_dtrmv = 57; - private static final int RsBlas_dtbmv = 58; - private static final int RsBlas_dtpmv = 59; - private static final int RsBlas_dtrsv = 60; - private static final int RsBlas_dtbsv = 61; - private static final int RsBlas_dtpsv = 62; - private static final int RsBlas_cgemv = 63; - private static final int RsBlas_cgbmv = 64; - private static final int RsBlas_ctrmv = 65; - private static final int RsBlas_ctbmv = 66; - private static final int RsBlas_ctpmv = 67; - private static final int RsBlas_ctrsv = 68; - private static final int RsBlas_ctbsv = 69; - private static final int RsBlas_ctpsv = 70; - private static final int RsBlas_zgemv = 71; - private static final int RsBlas_zgbmv = 72; - private static final int RsBlas_ztrmv = 73; - private static final int RsBlas_ztbmv = 74; - private static final int RsBlas_ztpmv = 75; - private static final int RsBlas_ztrsv = 76; - private static final int RsBlas_ztbsv = 77; - private static final int RsBlas_ztpsv = 78; - private static final int RsBlas_ssymv = 79; - private static final int RsBlas_ssbmv = 80; - private static final int RsBlas_sspmv = 81; - private static final int RsBlas_sger = 82; - private static final int RsBlas_ssyr = 83; - private static final int RsBlas_sspr = 84; - private static final int RsBlas_ssyr2 = 85; - private static final int RsBlas_sspr2 = 86; - private static final int RsBlas_dsymv = 87; - private static final int RsBlas_dsbmv = 88; - private static final int RsBlas_dspmv = 89; - private static final int RsBlas_dger = 90; - private static final int RsBlas_dsyr = 91; - private static final int RsBlas_dspr = 92; - private static final int RsBlas_dsyr2 = 93; - private static final int RsBlas_dspr2 = 94; - private static final int RsBlas_chemv = 95; - private static final int RsBlas_chbmv = 96; - private static final int RsBlas_chpmv = 97; - private static final int RsBlas_cgeru = 98; - private static final int RsBlas_cgerc = 99; - private static final int RsBlas_cher = 100; - private static final int RsBlas_chpr = 101; - private static final int RsBlas_cher2 = 102; - private static final int RsBlas_chpr2 = 103; - private static final int RsBlas_zhemv = 104; - private static final int RsBlas_zhbmv = 105; - private static final int RsBlas_zhpmv = 106; - private static final int RsBlas_zgeru = 107; - private static final int RsBlas_zgerc = 108; - private static final int RsBlas_zher = 109; - private static final int RsBlas_zhpr = 110; - private static final int RsBlas_zher2 = 111; - private static final int RsBlas_zhpr2 = 112; - private static final int RsBlas_sgemm = 113; - private static final int RsBlas_ssymm = 114; - private static final int RsBlas_ssyrk = 115; - private static final int RsBlas_ssyr2k = 116; - private static final int RsBlas_strmm = 117; - private static final int RsBlas_strsm = 118; - private static final int RsBlas_dgemm = 119; - private static final int RsBlas_dsymm = 120; - private static final int RsBlas_dsyrk = 121; - private static final int RsBlas_dsyr2k = 122; - private static final int RsBlas_dtrmm = 123; - private static final int RsBlas_dtrsm = 124; - private static final int RsBlas_cgemm = 125; - private static final int RsBlas_csymm = 126; - private static final int RsBlas_csyrk = 127; - private static final int RsBlas_csyr2k = 128; - private static final int RsBlas_ctrmm = 129; - private static final int RsBlas_ctrsm = 130; - private static final int RsBlas_zgemm = 131; - private static final int RsBlas_zsymm = 132; - private static final int RsBlas_zsyrk = 133; - private static final int RsBlas_zsyr2k = 134; - private static final int RsBlas_ztrmm = 135; - private static final int RsBlas_ztrsm = 136; - private static final int RsBlas_chemm = 137; - private static final int RsBlas_cherk = 138; - private static final int RsBlas_cher2k = 139; - private static final int RsBlas_zhemm = 140; - private static final int RsBlas_zherk = 141; - private static final int RsBlas_zher2k = 142; - - /** - */ - public static ScriptIntrinsicBLAS create(RenderScript rs) { - long id = rs.nScriptIntrinsicCreate(13, Element.U32(rs).getID(rs)); - return new ScriptIntrinsicBLAS(id, rs); - } - - @IntDef({NO_TRANSPOSE, TRANSPOSE, CONJ_TRANSPOSE}) - @Retention(RetentionPolicy.SOURCE) - public @interface Transpose {} - - @IntDef({UPPER, LOWER}) - @Retention(RetentionPolicy.SOURCE) - public @interface Uplo {} - - @IntDef({NON_UNIT, UNIT}) - @Retention(RetentionPolicy.SOURCE) - public @interface Diag {} - - @IntDef({LEFT, RIGHT}) - @Retention(RetentionPolicy.SOURCE) - public @interface Side {} - - public static final int NO_TRANSPOSE = 111; - public static final int TRANSPOSE = 112; - public static final int CONJ_TRANSPOSE = 113; - - public static final int UPPER = 121; - public static final int LOWER = 122; - - public static final int NON_UNIT = 131; - public static final int UNIT = 132; - - public static final int LEFT = 141; - public static final int RIGHT = 142; - - static void validateSide(@Side int Side) { - if (Side != LEFT && Side != RIGHT) { - throw new RSRuntimeException("Invalid side passed to BLAS"); - } - } - - static void validateTranspose(@Transpose int Trans) { - if (Trans != NO_TRANSPOSE && Trans != TRANSPOSE && - Trans != CONJ_TRANSPOSE) { - throw new RSRuntimeException("Invalid transpose passed to BLAS"); - } - } - - static void validateConjTranspose(@Transpose int Trans) { - if (Trans != NO_TRANSPOSE && - Trans != CONJ_TRANSPOSE) { - throw new RSRuntimeException("Invalid transpose passed to BLAS"); - } - } - - static void validateDiag(@Diag int Diag) { - if (Diag != NON_UNIT && Diag != UNIT) { - throw new RSRuntimeException("Invalid diag passed to BLAS"); - } - } - - static void validateUplo(@Uplo int Uplo) { - if (Uplo != LEFT && Uplo != RIGHT) { - throw new RSRuntimeException("Invalid uplo passed to BLAS"); - } - } - - - /** - * Level 2 BLAS - */ - - static void validateGEMV(Element e, int TransA, Allocation A, Allocation X, int incX, Allocation Y, int incY) { - validateTranspose(TransA); - int M = A.getType().getY(); - int N = A.getType().getX(); - if (!A.getType().getElement().isCompatible(e) || - !X.getType().getElement().isCompatible(e) || - !Y.getType().getElement().isCompatible(e)) { - throw new RSRuntimeException("Called BLAS with wrong Element type"); - } - if (X.getType().getY() > 1 || Y.getType().getY() > 1) { - throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); - } - - if (incX <= 0 || incY <= 0) { - throw new RSRuntimeException("Vector increments must be greater than 0"); - } - int expectedXDim = -1, expectedYDim = -1; - if (TransA == NO_TRANSPOSE) { - expectedXDim = 1 + (N - 1) * incX; - expectedYDim = 1 + (M - 1) * incY; - } else { - expectedXDim = 1 + (M - 1) * incX; - expectedYDim = 1 + (N - 1) * incY; - } - if (X.getType().getX() != expectedXDim || - Y.getType().getY() != expectedXDim) { - throw new RSRuntimeException("Incorrect vector dimensions for GEMV"); - } - } - void SGEMV(@Transpose int TransA, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { - validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY); - int M = A.getType().getY(); - int N = A.getType().getX(); - mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); - } - void DGEMV(@Transpose int TransA, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { - validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY); - int M = A.getType().getY(); - int N = A.getType().getX(); - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); - } - void CGEMV(@Transpose int TransA, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { - validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY); - int M = A.getType().getY(); - int N = A.getType().getX(); - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); - } - void ZGEMV(@Transpose int TransA, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { - validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY); - int M = A.getType().getY(); - int N = A.getType().getX(); - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); - } - - void SGBMV(@Transpose int TransA, int KL, int KU, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { - // GBMV has the same validation requirements as GEMV + KL and KU >= 0 - validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY); - if (KL < 0 || KU < 0) { - throw new RSRuntimeException("KL and KU must be greater than or equal to 0"); - } - int M = A.getType().getY(); - int N = A.getType().getX(); - mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU); - } - void DGBMV(@Transpose int TransA, int KL, int KU, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { - // GBMV has the same validation requirements as GEMV + KL and KU >= 0 - validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY); - if (KL < 0 || KU < 0) { - throw new RSRuntimeException("KL and KU must be greater than or equal to 0"); - } - int M = A.getType().getY(); - int N = A.getType().getX(); - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU); - } - void CGBMV(@Transpose int TransA, int KL, int KU, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { - // GBMV has the same validation requirements as GEMV + KL and KU >= 0 - validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY); - if (KL < 0 || KU < 0) { - throw new RSRuntimeException("KL and KU must be greater than or equal to 0"); - } - int M = A.getType().getY(); - int N = A.getType().getX(); - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU); - } - void ZGBMV(@Transpose int TransA, int KL, int KU, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { - // GBMV has the same validation requirements as GEMV + KL and KU >= 0 - validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY); - if (KL < 0 || KU < 0) { - throw new RSRuntimeException("KL and KU must be greater than or equal to 0"); - } - int M = A.getType().getY(); - int N = A.getType().getX(); - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU); - } - - static void validateTRMV(Element e, @Transpose int TransA, Allocation A, Allocation X, int incX) { - validateTranspose(TransA); - int N = A.getType().getY(); - if (A.getType().getX() != N) { - throw new RSRuntimeException("A must be a square matrix for TRMV"); - } - if (!A.getType().getElement().isCompatible(e) || - !X.getType().getElement().isCompatible(e)) { - throw new RSRuntimeException("Called BLAS with wrong Element type"); - } - if (X.getType().getY() > 1) { - throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); - } - - if (incX <= 0) { - throw new RSRuntimeException("Vector increments must be greater than 0"); - } - int expectedXDim = 1 + (N - 1) * incX; - if (X.getType().getX() != expectedXDim) { - throw new RSRuntimeException("Incorrect vector dimensions for TRMV"); - } - } - - static int validateTPMV(Element e, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { - validateTranspose(TransA); - validateUplo(Uplo); - validateDiag(Diag); - if (!Ap.getType().getElement().isCompatible(e) || - !X.getType().getElement().isCompatible(e)) { - throw new RSRuntimeException("Called BLAS with wrong Element type"); - } - if (X.getType().getY() > 1) { - throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); - } - - if (Ap.getType().getY() > 1) { - throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1"); - } - - int N = (int)Math.sqrt((double)Ap.getType().getX() * 2); - if (Ap.getType().getX() != ((N * (N+1)) / 2)) { - throw new RSRuntimeException("Invalid dimension for Ap"); - } - - int expectedXDim = 1 + (N - 1) * incX; - if (X.getType().getX() != expectedXDim) { - throw new RSRuntimeException("Incorrect vector dimensions for SYMV"); - } - - return N; - } - - void STRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { - validateTRMV(Element.F32(mRS), TransA, A, X, incX); - int N = A.getType().getY(); - mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); - } - void DTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { - validateTRMV(Element.F64(mRS), TransA, A, X, incX); - int N = A.getType().getY(); - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); - } - void CTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { - validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); - int N = A.getType().getY(); - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); - } - void ZTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { - validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); - int N = A.getType().getY(); - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); - } - void STBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBMV has the same requirements as TRMV - validateTRMV(Element.F32(mRS), TransA, A, X, incX); - int N = A.getType().getY(); - mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); - } - void DTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBMV has the same requirements as TRMV - validateTRMV(Element.F64(mRS), TransA, A, X, incX); - int N = A.getType().getY(); - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); - } - void CTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBMV has the same requirements as TRMV - validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); - int N = A.getType().getY(); - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); - } - void ZTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBMV has the same requirements as TRMV - validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); - int N = A.getType().getY(); - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); - } - void STPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { - int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX); - mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); - } - void DTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { - int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX); - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); - } - void CTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { - int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX); - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); - } - void ZTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { - int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX); - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); - } - void STRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { - // TRSV is the same as TRMV - validateTRMV(Element.F32(mRS), TransA, A, X, incX); - int N = A.getType().getY(); - mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); - - } - void DTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { - // TRSV is the same as TRMV - validateTRMV(Element.F64(mRS), TransA, A, X, incX); - int N = A.getType().getY(); - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); - - } - void CTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { - // TRSV is the same as TRMV - validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); - int N = A.getType().getY(); - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); - - } - void ZTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { - // TRSV is the same as TRMV - validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); - int N = A.getType().getY(); - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); - - } - void STBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBSV is the same as TRMV - validateTRMV(Element.F32(mRS), TransA, A, X, incX); - int N = A.getType().getY(); - if (K < 0) { - throw new RSRuntimeException("Number of diagonals must be positive"); - } - mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); - } - void DTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBSV is the same as TRMV - validateTRMV(Element.F64(mRS), TransA, A, X, incX); - int N = A.getType().getY(); - if (K < 0) { - throw new RSRuntimeException("Number of diagonals must be positive"); - } - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); - } - void CTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBSV is the same as TRMV - validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); - int N = A.getType().getY(); - if (K < 0) { - throw new RSRuntimeException("Number of diagonals must be positive"); - } - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); - } - void ZTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { - // TBSV is the same as TRMV - validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); - int N = A.getType().getY(); - if (K < 0) { - throw new RSRuntimeException("Number of diagonals must be positive"); - } - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); - } - void STPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { - // TPSV is same as TPMV - int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX); - mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); - } - void DTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { - // TPSV is same as TPMV - int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX); - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); - } - void CTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { - // TPSV is same as TPMV - int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX); - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); - } - void ZTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { - // TPSV is same as TPMV - int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX); - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); - } - - /** - * Level 2, S and D only - */ - static int validateSYMV(Element e, @Uplo int Uplo, Allocation A, Allocation X, Allocation Y, int incX, int incY) { - validateUplo(Uplo); - int N = A.getType().getY(); - if (A.getType().getX() != N) { - throw new RSRuntimeException("A must be a square matrix for SYMV"); - } - if (!A.getType().getElement().isCompatible(e) || - !X.getType().getElement().isCompatible(e) || - !Y.getType().getElement().isCompatible(e) ) { - throw new RSRuntimeException("Called BLAS with wrong Element type"); - } - if (X.getType().getY() > 1 || Y.getType().getY() > 1) { - throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); - } - - if (incX <= 0 || incY <= 0) { - throw new RSRuntimeException("Vector increments must be greater than 0"); - } - int expectedXDim = 1 + (N - 1) * incX; - if (X.getType().getX() != expectedXDim) { - throw new RSRuntimeException("Incorrect vector dimensions for SYMV"); - } - int expectedYDim = 1 + (N - 1) * incY; - if (Y.getType().getX() != expectedYDim) { - throw new RSRuntimeException("Incorrect vector dimensions for SYMV"); - } - return N; - } - static int validateSPMV(Element e, @Uplo int Uplo, Allocation Ap, Allocation X, int incX, Allocation Y, int incY) { - validateUplo(Uplo); - if (!Ap.getType().getElement().isCompatible(e) || - !X.getType().getElement().isCompatible(e) || - !Y.getType().getElement().isCompatible(e)) { - throw new RSRuntimeException("Called BLAS with wrong Element type"); - } - if (X.getType().getY() > 1 || Y.getType().getY() > 1) { - throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); - } - - if (Ap.getType().getY() > 1) { - throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1"); - } - - int N = (int)Math.sqrt((double)Ap.getType().getX() * 2); - if (Ap.getType().getX() != ((N * (N+1)) / 2)) { - throw new RSRuntimeException("Invalid dimension for Ap"); - } - - int expectedXDim = 1 + (N - 1) * incX; - if (X.getType().getX() != expectedXDim) { - throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); - } - int expectedYDim = 1 + (N - 1) * incY; - if (Y.getType().getX() != expectedYDim) { - throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); - } - - return N; - } - static void validateGER(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) { - if (!A.getType().getElement().isCompatible(e) || - !X.getType().getElement().isCompatible(e) || - !Y.getType().getElement().isCompatible(e) ) { - throw new RSRuntimeException("Called BLAS with wrong Element type"); - } - - if (X.getType().getY() > 1 || Y.getType().getY() > 1) { - throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); - } - - int M = A.getType().getY(); - int N = A.getType().getX(); - - if (N < 1 || M < 1) { - throw new RSRuntimeException("M and N must be 1 or greater for GER"); - } - - int expectedXDim = 1 + (N - 1) * incX; - if (X.getType().getX() != expectedXDim) { - throw new RSRuntimeException("Incorrect vector dimensions for GER"); - } - int expectedYDim = 1 + (N - 1) * incY; - if (Y.getType().getX() != expectedYDim) { - throw new RSRuntimeException("Incorrect vector dimensions for GER"); - } - - - } - static int validateSYR(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation A) { - validateUplo(Uplo); - if (!A.getType().getElement().isCompatible(e) || - !X.getType().getElement().isCompatible(e)) { - throw new RSRuntimeException("Called BLAS with wrong Element type"); - } - - int N = A.getType().getX(); - - if (X.getType().getY() > 1) { - throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); - } - if (N != A.getType().getY()) { - throw new RSRuntimeException("A must be a symmetric matrix"); - } - - int expectedXDim = 1 + (N - 1) * incX; - if (X.getType().getX() != expectedXDim) { - throw new RSRuntimeException("Incorrect vector dimensions for SYR"); - } - return N; - } - static int validateSPR(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Ap) { - validateUplo(Uplo); - if (!Ap.getType().getElement().isCompatible(e) || - !X.getType().getElement().isCompatible(e)) { - throw new RSRuntimeException("Called BLAS with wrong Element type"); - } - if (X.getType().getY() > 1) { - throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); - } - - if (Ap.getType().getY() > 1) { - throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1"); - } - - int N = (int)Math.sqrt((double)Ap.getType().getX() * 2); - if (Ap.getType().getX() != ((N * (N+1)) / 2)) { - throw new RSRuntimeException("Invalid dimension for Ap"); - } - - int expectedXDim = 1 + (N - 1) * incX; - if (X.getType().getX() != expectedXDim) { - throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); - } - - return N; - } - - static int validateSYR2(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation A) { - validateUplo(Uplo); - if (!A.getType().getElement().isCompatible(e) || - !X.getType().getElement().isCompatible(e) || - !Y.getType().getElement().isCompatible(e)) { - throw new RSRuntimeException("Called BLAS with wrong Element type"); - } - - if (X.getType().getY() > 1 || Y.getType().getY() > 1) { - throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); - } - - int N = A.getType().getX(); - - if (N != A.getType().getY()) { - throw new RSRuntimeException("A must be a symmetric matrix"); - } - - int expectedXDim = 1 + (N - 1) * incX; - int expectedYDim = 1 + (N - 1) * incY; - if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) { - throw new RSRuntimeException("Incorrect vector dimensions for SYR"); - } - return N; - - } - static int validateSPR2(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { - validateUplo(Uplo); - if (!Ap.getType().getElement().isCompatible(e) || - !X.getType().getElement().isCompatible(e) || - !Y.getType().getElement().isCompatible(e)) { - throw new RSRuntimeException("Called BLAS with wrong Element type"); - } - if (X.getType().getY() > 1 || Y.getType().getY() > 1) { - throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); - } - - if (Ap.getType().getY() > 1) { - throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1"); - } - - int N = (int)Math.sqrt((double)Ap.getType().getX() * 2); - if (Ap.getType().getX() != ((N * (N+1)) / 2)) { - throw new RSRuntimeException("Invalid dimension for Ap"); - } - - int expectedXDim = 1 + (N - 1) * incX; - int expectedYDim = 1 + (N - 1) * incY; - if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) { - throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); - } - - return N; - } - - void SSYMV(@Uplo int Uplo, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { - int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY); - mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); - } - void SSBMV(@Uplo int Uplo, int K, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { - // SBMV is the same as SYMV - int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY); - mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); - } - void SSPMV(@Uplo int Uplo, float alpha, Allocation Ap, Allocation X, int incX, float beta, Allocation Y, int incY) { - int N = validateSPMV(Element.F32(mRS), Uplo, Ap, X, incX, Y, incY); - mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); - } - void SGER(float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { - int M = A.getType().getY(); - int N = A.getType().getX(); - mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0); - } - void SSYR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) { - int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A); - mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0); - } - void SSPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) { - int N = validateSPR(Element.F32(mRS), Uplo, X, incX, Ap); - mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0); - } - void SSYR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { - int N = validateSYR2(Element.F32(mRS), Uplo, X, incX, Y, incY, A); - mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0); - } - void SSPR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { - int N = validateSPR2(Element.F32(mRS), Uplo, X, incX, Y, incY, Ap); - mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0); - } - void DSYMV(@Uplo int Uplo, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { - int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY); - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); - } - void DSBMV(@Uplo int Uplo, int K, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { - // SBMV is the same as SYMV - int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY); - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); - } - void DSPMV(@Uplo int Uplo, double alpha, Allocation Ap, Allocation X, int incX, double beta, Allocation Y, int incY) { - int N = validateSPMV(Element.F64(mRS), Uplo, Ap, X, incX, Y, incY); - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); - } - void DGER(double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { - int M = A.getType().getY(); - int N = A.getType().getX(); - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0); - } - void DSYR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) { - int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A); - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0); - } - void DSPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) { - int N = validateSPR(Element.F64(mRS), Uplo, X, incX, Ap); - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0); - } - void DSYR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { - int N = validateSYR2(Element.F64(mRS), Uplo, X, incX, Y, incY, A); - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0); - } - void DSPR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { - int N = validateSPR2(Element.F64(mRS), Uplo, X, incX, Y, incY, Ap); - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0); - } - - - /** - * Level 2, C and Z only - */ - - static void validateGERU(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) { - if (!A.getType().getElement().isCompatible(e) || - !X.getType().getElement().isCompatible(e) || - !Y.getType().getElement().isCompatible(e)) { - throw new RSRuntimeException("Called BLAS with wrong Element type"); - } - if (X.getType().getY() > 1 || Y.getType().getY() > 1) { - throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); - } - - int M = A.getType().getY(); - int N = A.getType().getX(); - - int expectedXDim = 1 + (N - 1) * incX; - if (X.getType().getX() != expectedXDim) { - throw new RSRuntimeException("Incorrect vector dimensions for GERU"); - } - int expectedYDim = 1 + (N - 1) * incY; - if (Y.getType().getX() != expectedYDim) { - throw new RSRuntimeException("Incorrect vector dimensions for GERU"); - } - - } - - void CHEMV(@Uplo int Uplo, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { - // HEMV is the same as SYR2 validation-wise - int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A); - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); - } - void CHBMV(@Uplo int Uplo, int K, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { - // HBMV is the same as SYR2 validation-wise - int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A); - if (K < 0) { - throw new RSRuntimeException("K must be 0 or greater for HBMV"); - } - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); - } - void CHPMV(@Uplo int Uplo, Float2 alpha, Allocation Ap, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { - // HPMV is the same as SPR2 - int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap); - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); - } - void CGERU(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { - validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A); - int M = A.getType().getY(); - int N = A.getType().getX(); - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); - } - void CGERC(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { - // same as GERU - validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A); - int M = A.getType().getY(); - int N = A.getType().getX(); - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); - } - void CHER(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) { - // same as SYR - int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A); - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0); - } - void CHPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) { - // equivalent to SPR for validation - int N = validateSPR(Element.F32_2(mRS), Uplo, X, incX, Ap); - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0); - } - void CHER2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { - // same as SYR2 - int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A); - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); - } - void CHPR2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { - // same as SPR2 - int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap); - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0); - } - void ZHEMV(@Uplo int Uplo, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { - // HEMV is the same as SYR2 validation-wise - int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A); - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); - } - void ZHBMV(@Uplo int Uplo, int K, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { - // HBMV is the same as SYR2 validation-wise - int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A); - if (K < 0) { - throw new RSRuntimeException("K must be 0 or greater for HBMV"); - } - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); - } - void ZHPMV(@Uplo int Uplo, Double2 alpha, Allocation Ap, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { - // HPMV is the same as SPR2 - int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap); - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); - } - void ZGERU(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { - validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A); - int M = A.getType().getY(); - int N = A.getType().getX(); - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); - } - void ZGERC(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { - // same as GERU - validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A); - int M = A.getType().getY(); - int N = A.getType().getX(); - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); - } - void ZHER(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) { - // same as SYR - int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A); - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0); - } - void ZHPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) { - // equivalent to SPR for validation - int N = validateSPR(Element.F64_2(mRS), Uplo, X, incX, Ap); - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0); - } - void ZHER2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { - // same as SYR2 - int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A); - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); - } - void ZHPR2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { - // same as SPR2 - int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap); - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0); - } - - - /** - * Level 3 BLAS - */ - - static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) { - int aX = -1, aY = -1, bX = -1, bY = -1, cX = -1, cY = -1; - if ((A != null && !A.getType().getElement().isCompatible(e)) || - (B != null && !B.getType().getElement().isCompatible(e)) || - (C != null && !C.getType().getElement().isCompatible(e))) { - throw new RSRuntimeException("Called BLAS with wrong Element type"); - } - if (C != null) { - cX = C.getType().getY(); - cY = C.getType().getX(); - } - if (Side == RIGHT) { - if (B != null) { - bX = A.getType().getY(); - bY = A.getType().getX(); - } - if (A != null) { - aX = B.getType().getY(); - aY = B.getType().getX(); - } - } else { - if (A != null) { - if (TransA == TRANSPOSE) { - aY = A.getType().getY(); - aX = A.getType().getX(); - } else { - aX = A.getType().getY(); - aY = A.getType().getX(); - } - } - if (B != null) { - if (TransB == TRANSPOSE) { - bY = B.getType().getY(); - bX = B.getType().getX(); - } else { - bX = B.getType().getY(); - bY = B.getType().getX(); - } - } - } - if (A != null && B != null && C != null) { - if (aY != bX || aX != cX || bY != cY) { - throw new RSRuntimeException("Called BLAS with invalid dimensions"); - } - } else if (A != null && C != null) { - // A and C only - if (aX != cY || aY != cX) { - throw new RSRuntimeException("Called BLAS with invalid dimensions"); - } - } else if (A != null && B != null) { - // A and B only - } - - } - - public void SGEMM(@Transpose int TransA, @Transpose int TransB, float alpha, Allocation A, - Allocation B, float beta, Allocation C) { - validateTranspose(TransA); - validateTranspose(TransB); - validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C); - - int M = -1, N = -1, K = -1; - if (TransA == TRANSPOSE) { - M = A.getType().getX(); - K = A.getType().getY(); - } else { - M = A.getType().getY(); - K = A.getType().getX(); - } - if (TransB == TRANSPOSE) { - N = B.getType().getY(); - } else { - N = B.getType().getX(); - } - mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A.getID(mRS), B.getID(mRS), - beta, C.getID(mRS), 0, 0, 0, 0); - } - public void DGEMM(@Transpose int TransA, @Transpose int TransB, double alpha, Allocation A, - Allocation B, double beta, Allocation C) { - validateTranspose(TransA); - validateTranspose(TransB); - validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C); - int M = -1, N = -1, K = -1; - if (TransA == TRANSPOSE) { - M = A.getType().getX(); - K = A.getType().getY(); - } else { - M = A.getType().getY(); - K = A.getType().getX(); - } - if (TransB == TRANSPOSE) { - N = B.getType().getY(); - } else { - N = B.getType().getX(); - } - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A.getID(mRS), B.getID(mRS), - beta, C.getID(mRS), 0, 0, 0, 0); - } - public void CGEMM(@Transpose int TransA, @Transpose int TransB, Float2 alpha, Allocation A, - Allocation B, Float2 beta, Allocation C) { - validateTranspose(TransA); - validateTranspose(TransB); - validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C); - int M = -1, N = -1, K = -1; - if (TransA == TRANSPOSE) { - M = A.getType().getX(); - K = A.getType().getY(); - } else { - M = A.getType().getY(); - K = A.getType().getX(); - } - if (TransB == TRANSPOSE) { - N = B.getType().getY(); - } else { - N = B.getType().getX(); - } - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), - beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); - } - - public void ZGEMM(@Transpose int TransA, @Transpose int TransB, Double2 alpha, Allocation A, - Allocation B, Double2 beta, Allocation C) { - validateTranspose(TransA); - validateTranspose(TransB); - validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C); - int M = -1, N = -1, K = -1; - if (TransA == TRANSPOSE) { - M = A.getType().getX(); - K = A.getType().getY(); - } else { - M = A.getType().getY(); - K = A.getType().getX(); - } - if (TransB == TRANSPOSE) { - N = B.getType().getY(); - } else { - N = B.getType().getX(); - } - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), - beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); - } - - public void SSYMM(@Side int Side, @Uplo int Uplo, float alpha, Allocation A, - Allocation B, float beta, Allocation C) { - validateSide(Side); - validateUplo(Uplo); - validateL3(Element.F32(mRS), 0, 0, Side, A, B, C); - mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), - beta, C.getID(mRS), 0, 0, 0, 0); - } - public void DSYMM(@Side int Side, @Uplo int Uplo, double alpha, Allocation A, - Allocation B, double beta, Allocation C) { - validateSide(Side); - validateUplo(Uplo); - validateL3(Element.F64(mRS), 0, 0, Side, A, B, C); - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), - beta, C.getID(mRS), 0, 0, 0, 0); - } - public void CSYMM(@Side int Side, @Uplo int Uplo, Float2 alpha, Allocation A, - Allocation B, Float2 beta, Allocation C) { - validateSide(Side); - validateUplo(Uplo); - validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C); - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), - beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); - } - public void ZSYMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A, - Allocation B, Double2 beta, Allocation C) { - validateSide(Side); - validateUplo(Uplo); - validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C); - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), - beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); - } - - public void SSYRK(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, float beta, Allocation C) { - validateTranspose(Trans); - validateUplo(Uplo); - validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C); - int K = -1; - if (Trans == TRANSPOSE) { - K = A.getType().getY(); - } else { - K = A.getType().getX(); - } - - mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0); - } - - public void DSYRK(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, double beta, Allocation C) { - validateTranspose(Trans); - validateUplo(Uplo); - validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C); - int K = -1; - if (Trans == TRANSPOSE) { - K = A.getType().getY(); - } else { - K = A.getType().getX(); - } - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0); - } - public void CSYRK(@Uplo int Uplo, @Transpose int Trans, float alphaX, float alphaY, Allocation A, float betaX, float betaY, Allocation C) { - validateTranspose(Trans); - validateUplo(Uplo); - validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C); - int K = -1; - if (Trans == TRANSPOSE) { - K = A.getType().getY(); - } else { - K = A.getType().getX(); - } - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alphaX, alphaY, A.getID(mRS), 0, betaX, betaY, - C.getID(mRS), 0, 0, 0, 0); - } - public void ZSYRK(@Uplo int Uplo, @Transpose int Trans, double alphaX, double alphaY, Allocation A, double betaX, double betaY, Allocation C) { - validateTranspose(Trans); - validateUplo(Uplo); - validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C); - int K = -1; - if (Trans == TRANSPOSE) { - K = A.getType().getY(); - } else { - K = A.getType().getX(); - } - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alphaX, alphaY, A.getID(mRS), 0, betaX, betaY, - C.getID(mRS), 0, 0, 0, 0); - } - - static void validateSYR2K(Element e, @Transpose int Trans, Allocation A, Allocation B, Allocation C) { - validateTranspose(Trans); - if (!A.getType().getElement().isCompatible(e) || - !B.getType().getElement().isCompatible(e) || - !C.getType().getElement().isCompatible(e)) { - throw new RSRuntimeException("Called BLAS with wrong Element type"); - } - int Cdim = -1; - // A is n x k if no transpose, k x n if transpose - // C is n x n - if (Trans == TRANSPOSE) { - // check columns versus C - Cdim = A.getType().getX(); - } else { - // check rows versus C - Cdim = A.getType().getY(); - } - if (C.getType().getX() != Cdim && C.getType().getY() != Cdim) { - throw new RSRuntimeException("Invalid symmetric matrix in SYR2K"); - } - // A dims == B dims - if (A.getType().getX() != B.getType().getX() || A.getType().getY() != B.getType().getY()) { - throw new RSRuntimeException("Invalid A and B in SYR2K"); - } - } - public void SSYR2K(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, Allocation B, float beta, Allocation C) { - validateUplo(Uplo); - validateSYR2K(Element.F32(mRS), Trans, A, B, C); - int K = -1; - if (Trans == TRANSPOSE) { - K = A.getType().getY(); - } else { - K = A.getType().getX(); - } - mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); - } - public void DSYR2K(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, Allocation B, double beta, Allocation C) { - validateUplo(Uplo); - validateSYR2K(Element.F64(mRS), Trans, A, B, C); - int K = -1; - if (Trans == TRANSPOSE) { - K = A.getType().getY(); - } else { - K = A.getType().getX(); - } - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); - } - public void CSYR2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, Float2 beta, Allocation C) { - validateUplo(Uplo); - validateSYR2K(Element.F32_2(mRS), Trans, A, B, C); - int K = -1; - if (Trans == TRANSPOSE) { - K = A.getType().getY(); - } else { - K = A.getType().getX(); - } - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); - } - public void ZSYR2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) { - validateUplo(Uplo); - validateSYR2K(Element.F64_2(mRS), Trans, A, B, C); - int K = -1; - if (Trans == TRANSPOSE) { - K = A.getType().getY(); - } else { - K = A.getType().getX(); - } - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); - } - - static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { - validateSide(Side); - validateTranspose(TransA); - int aX = -1, aY = -1, bX = -1, bY = -1; - if (!A.getType().getElement().isCompatible(e) || - !B.getType().getElement().isCompatible(e)) { - throw new RSRuntimeException("Called BLAS with wrong Element type"); - } - if (TransA == TRANSPOSE) { - aY = A.getType().getY(); - aX = A.getType().getX(); - } else { - aY = A.getType().getX(); - aX = A.getType().getY(); - } - bX = B.getType().getY(); - bY = B.getType().getX(); - if (Side == LEFT) { - if (aX == 0 || aY != bX) { - throw new RSRuntimeException("Called TRMM with invalid matrices"); - } - } else { - if (bY != aX || aY == 0) { - throw new RSRuntimeException("Called TRMM with invalid matrices"); - } - } - } - public void STRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, float alpha, Allocation A, Allocation B) { - validateUplo(Uplo); - validateDiag(Diag); - validateTRMM(Element.F32(mRS), Side, TransA, A, B); - mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, - alpha, A.getID(mRS), B.getID(mRS), 0.f, 0, 0, 0, 0, 0); - } - public void DTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, double alpha, Allocation A, Allocation B) { - validateUplo(Uplo); - validateDiag(Diag); - validateTRMM(Element.F64(mRS), Side, TransA, A, B); - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, - alpha, A.getID(mRS), B.getID(mRS), 0.f, 0, 0, 0, 0, 0); - } - public void CTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) { - validateUplo(Uplo); - validateDiag(Diag); - validateTRMM(Element.F32_2(mRS), Side, TransA, A, B); - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, - alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); - } - public void ZTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) { - validateUplo(Uplo); - validateDiag(Diag); - validateTRMM(Element.F64_2(mRS), Side, TransA, A, B); - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, - alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); - } - - static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { - int adim = -1, bX = -1, bY = -1; - validateSide(Side); - validateTranspose(TransA); - if (!A.getType().getElement().isCompatible(e) || - !B.getType().getElement().isCompatible(e)) { - throw new RSRuntimeException("Called BLAS with wrong Element type"); - } - adim = A.getType().getX(); - if (adim != A.getType().getY()) { - // this may be unnecessary, the restriction could potentially be relaxed - // A needs to contain at least that symmetric matrix but could theoretically be larger - // for now we assume adapters are sufficient, will reevaluate in the future - throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A"); - } - bX = B.getType().getY(); - bY = B.getType().getX(); - if (Side == LEFT) { - // A is M*M - if (adim != bY) { - throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); - } - } else { - // A is N*N - if (adim != bX) { - throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); - } - } - } - public void STRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, float alpha, Allocation A, Allocation B) { - validateUplo(Uplo); - validateDiag(Diag); - validateTRSM(Element.F32(mRS), Side, TransA, A, B); - mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, - alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0); - } - public void DTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, double alpha, Allocation A, Allocation B) { - validateUplo(Uplo); - validateDiag(Diag); - validateTRSM(Element.F64(mRS), Side, TransA, A, B); - mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, - alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0); - } - public void CTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) { - validateUplo(Uplo); - validateDiag(Diag); - validateTRSM(Element.F32_2(mRS), Side, TransA, A, B); - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, - alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); - } - public void ZTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) { - validateUplo(Uplo); - validateDiag(Diag); - validateTRSM(Element.F64_2(mRS), Side, TransA, A, B); - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, - alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); - } - - static void validateHEMM(Element e, @Side int Side, Allocation A, Allocation B, Allocation C) { - validateSide(Side); - - if (!A.getType().getElement().isCompatible(e) || - !B.getType().getElement().isCompatible(e) || - !C.getType().getElement().isCompatible(e)) { - throw new RSRuntimeException("Called BLAS with wrong Element type"); - } - - // A must be square; can potentially be relaxed similar to TRSM - int adim = A.getType().getX(); - if (adim != A.getType().getY()) { - throw new RSRuntimeException("Called HEMM with non-square A"); - } - if ((Side == LEFT && adim != B.getType().getY()) || - (Side == RIGHT && adim != B.getType().getX())) { - throw new RSRuntimeException("Called HEMM with invalid B"); - } - if (B.getType().getX() != C.getType().getX() || - B.getType().getY() != C.getType().getY()) { - throw new RSRuntimeException("Called HEMM with mismatched B and C"); - } - } - public void CHEMM(@Side int Side, @Uplo int Uplo, float alpha, Allocation A, Allocation B, float beta, Allocation C) { - validateUplo(Uplo); - validateHEMM(Element.F32_2(mRS), Side, A, B, C); - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, - alpha, 0, A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0); - } - public void ZHEMM(@Side int Side, @Uplo int Uplo, double alpha, Allocation A, Allocation B, double beta, Allocation C) { - validateUplo(Uplo); - validateHEMM(Element.F32_2(mRS), Side, A, B, C); - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, - alpha, 0, A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0); - } - - static void validateHERK(Element e, @Transpose int Trans, Allocation A, Allocation C) { - if (!A.getType().getElement().isCompatible(e) || - !C.getType().getElement().isCompatible(e)) { - throw new RSRuntimeException("Called BLAS with wrong Element type"); - } - validateConjTranspose(Trans); - int cdim = C.getType().getX(); - if (cdim != C.getType().getY()) { - throw new RSRuntimeException("Called HERK with non-square C"); - } - if (Trans == NO_TRANSPOSE) { - if (cdim != A.getType().getX()) { - throw new RSRuntimeException("Called HERK with invalid A"); - } - } else { - if (cdim != A.getType().getY()) { - throw new RSRuntimeException("Called HERK with invalid A"); - } - } - } - public void CHERK(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, float beta, Allocation C) { - validateUplo(Uplo); - validateHERK(Element.F32_2(mRS), Trans, A, C); - int k = 0; - if (Trans == TRANSPOSE) { - k = A.getType().getY(); - } else { - k = A.getType().getX(); - } - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cherk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, - alpha, 0, A.getID(mRS), 0, beta, 0, C.getID(mRS), 0, 0, 0, 0); - } - public void ZHERK(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, double beta, Allocation C) { - validateUplo(Uplo); - validateHERK(Element.F64_2(mRS), Trans, A, C); - int k = 0; - if (Trans == TRANSPOSE) { - k = A.getType().getY(); - } else { - k = A.getType().getX(); - } - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zherk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, - alpha, 0, A.getID(mRS), 0, beta, 0, C.getID(mRS), 0, 0, 0, 0); - } - - static void validateHER2K(Element e, @Transpose int Trans, Allocation A, Allocation B, Allocation C) { - if (!A.getType().getElement().isCompatible(e) || - !B.getType().getElement().isCompatible(e) || - !C.getType().getElement().isCompatible(e)) { - throw new RSRuntimeException("Called BLAS with wrong Element type"); - } - validateConjTranspose(Trans); - int cdim = C.getType().getX(); - if (cdim != C.getType().getY()) { - throw new RSRuntimeException("Called HER2K with non-square C"); - } - if (Trans == NO_TRANSPOSE) { - if (A.getType().getY() != cdim) { - throw new RSRuntimeException("Called HER2K with invalid matrices"); - } - } else { - if (A.getType().getX() != cdim) { - throw new RSRuntimeException("Called HER2K with invalid matrices"); - } - } - if (A.getType().getX() != B.getType().getX() || A.getType().getY() != B.getType().getY()) { - throw new RSRuntimeException("Called HER2K with invalid A and B matrices"); - } - } - public void CHER2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, float beta, Allocation C) { - validateUplo(Uplo); - validateHER2K(Element.F32_2(mRS), Trans, A, B, C); - int k = 0; - if (Trans == NO_TRANSPOSE) { - k = A.getType().getX(); - } else { - k = A.getType().getY(); - } - mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, alpha.x, alpha.y, - A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0); - } - public void ZHER2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, double beta, Allocation C) { - validateUplo(Uplo); - validateHER2K(Element.F64_2(mRS), Trans, A, B, C); - int k = 0; - if (Trans == NO_TRANSPOSE) { - k = A.getType().getX(); - } else { - k = A.getType().getY(); - } - mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, alpha.x, alpha.y, - A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0); - } - - - -} diff --git a/rs/jni/android_renderscript_RenderScript.cpp b/rs/jni/android_renderscript_RenderScript.cpp index 2612323c48877..198cabe1b0502 100644 --- a/rs/jni/android_renderscript_RenderScript.cpp +++ b/rs/jni/android_renderscript_RenderScript.cpp @@ -309,143 +309,6 @@ nScriptGroup2Execute(JNIEnv *_env, jobject _this, jlong con, jlong groupID) { rsScriptGroupExecute((RsContext)con, (RsScriptGroup2)groupID); } -static void -nScriptIntrinsicBLAS_Single(JNIEnv *_env, jobject _this, jlong con, jlong id, jint func, jint TransA, - jint TransB, jint Side, jint Uplo, jint Diag, jint M, jint N, jint K, - jfloat alpha, jlong A, jlong B, jfloat beta, jlong C, jint incX, jint incY, - jint KL, jint KU) { - RsBlasCall call; - memset(&call, 0, sizeof(call)); - call.func = (RsBlasFunction)func; - call.transA = (RsBlasTranspose)TransA; - call.transB = (RsBlasTranspose)TransB; - call.side = (RsBlasSide)Side; - call.uplo = (RsBlasUplo)Uplo; - call.diag = (RsBlasDiag)Diag; - call.M = M; - call.N = N; - call.K = K; - call.alpha.f = alpha; - call.beta.f = beta; - call.incX = incX; - call.incY = incY; - call.KL = KL; - call.KU = KU; - - RsAllocation in_allocs[3]; - in_allocs[0] = (RsAllocation)A; - in_allocs[1] = (RsAllocation)B; - in_allocs[2] = (RsAllocation)C; - - rsScriptForEachMulti((RsContext)con, (RsScript)id, 0, - in_allocs, sizeof(in_allocs), nullptr, - &call, sizeof(call), nullptr, 0); -} - -static void -nScriptIntrinsicBLAS_Double(JNIEnv *_env, jobject _this, jlong con, jlong id, jint func, jint TransA, - jint TransB, jint Side, jint Uplo, jint Diag, jint M, jint N, jint K, - jdouble alpha, jlong A, jlong B, jdouble beta, jlong C, jint incX, jint incY, - jint KL, jint KU) { - RsBlasCall call; - memset(&call, 0, sizeof(call)); - call.func = (RsBlasFunction)func; - call.transA = (RsBlasTranspose)TransA; - call.transB = (RsBlasTranspose)TransB; - call.side = (RsBlasSide)Side; - call.uplo = (RsBlasUplo)Uplo; - call.diag = (RsBlasDiag)Diag; - call.M = M; - call.N = N; - call.K = K; - call.alpha.d = alpha; - call.beta.d = beta; - call.incX = incX; - call.incY = incY; - call.KL = KL; - call.KU = KU; - - RsAllocation in_allocs[3]; - in_allocs[0] = (RsAllocation)A; - in_allocs[1] = (RsAllocation)B; - in_allocs[2] = (RsAllocation)C; - - rsScriptForEachMulti((RsContext)con, (RsScript)id, 0, - in_allocs, sizeof(in_allocs), nullptr, - &call, sizeof(call), nullptr, 0); -} - -static void -nScriptIntrinsicBLAS_Complex(JNIEnv *_env, jobject _this, jlong con, jlong id, jint func, jint TransA, - jint TransB, jint Side, jint Uplo, jint Diag, jint M, jint N, jint K, - jfloat alphaX, jfloat alphaY, jlong A, jlong B, jfloat betaX, - jfloat betaY, jlong C, jint incX, jint incY, jint KL, jint KU) { - RsBlasCall call; - memset(&call, 0, sizeof(call)); - call.func = (RsBlasFunction)func; - call.transA = (RsBlasTranspose)TransA; - call.transB = (RsBlasTranspose)TransB; - call.side = (RsBlasSide)Side; - call.uplo = (RsBlasUplo)Uplo; - call.diag = (RsBlasDiag)Diag; - call.M = M; - call.N = N; - call.K = K; - call.alpha.c.r = alphaX; - call.alpha.c.i = alphaY; - call.beta.c.r = betaX; - call.beta.c.r = betaY; - call.incX = incX; - call.incY = incY; - call.KL = KL; - call.KU = KU; - - RsAllocation in_allocs[3]; - in_allocs[0] = (RsAllocation)A; - in_allocs[1] = (RsAllocation)B; - in_allocs[2] = (RsAllocation)C; - - rsScriptForEachMulti((RsContext)con, (RsScript)id, 0, - in_allocs, sizeof(in_allocs), nullptr, - &call, sizeof(call), nullptr, 0); -} - -static void -nScriptIntrinsicBLAS_Z(JNIEnv *_env, jobject _this, jlong con, jlong id, jint func, jint TransA, - jint TransB, jint Side, jint Uplo, jint Diag, jint M, jint N, jint K, - jdouble alphaX, jdouble alphaY, jlong A, jlong B, jdouble betaX, - jdouble betaY, jlong C, jint incX, jint incY, jint KL, jint KU) { - RsBlasCall call; - memset(&call, 0, sizeof(call)); - call.func = (RsBlasFunction)func; - call.transA = (RsBlasTranspose)TransA; - call.transB = (RsBlasTranspose)TransB; - call.side = (RsBlasSide)Side; - call.uplo = (RsBlasUplo)Uplo; - call.diag = (RsBlasDiag)Diag; - call.M = M; - call.N = N; - call.K = K; - call.alpha.z.r = alphaX; - call.alpha.z.i = alphaY; - call.beta.z.r = betaX; - call.beta.z.r = betaY; - call.incX = incX; - call.incY = incY; - call.KL = KL; - call.KU = KU; - - RsAllocation in_allocs[3]; - in_allocs[0] = (RsAllocation)A; - in_allocs[1] = (RsAllocation)B; - in_allocs[2] = (RsAllocation)C; - - rsScriptForEachMulti((RsContext)con, (RsScript)id, 0, - in_allocs, sizeof(in_allocs), nullptr, - &call, sizeof(call), nullptr, 0); -} - - static void nAssignName(JNIEnv *_env, jobject _this, jlong con, jlong obj, jbyteArray str) { @@ -2231,11 +2094,6 @@ static JNINativeMethod methods[] = { {"rsnScriptGroupExecute", "(JJ)V", (void*)nScriptGroupExecute }, {"rsnScriptGroup2Execute", "(JJ)V", (void*)nScriptGroup2Execute }, -{"rsnScriptIntrinsicBLAS_Single", "(JJIIIIIIIIIFJJFJIIII)V", (void*)nScriptIntrinsicBLAS_Single }, -{"rsnScriptIntrinsicBLAS_Double", "(JJIIIIIIIIIDJJDJIIII)V", (void*)nScriptIntrinsicBLAS_Double }, -{"rsnScriptIntrinsicBLAS_Complex", "(JJIIIIIIIIIFFJJFFJIIII)V", (void*)nScriptIntrinsicBLAS_Complex }, -{"rsnScriptIntrinsicBLAS_Z", "(JJIIIIIIIIIDDJJDDJIIII)V", (void*)nScriptIntrinsicBLAS_Z }, - {"rsnProgramStoreCreate", "(JZZZZZZIII)J", (void*)nProgramStoreCreate }, {"rsnProgramBindConstants", "(JJIJ)V", (void*)nProgramBindConstants },