Skip to content
This repository has been archived by the owner on Aug 28, 2024. It is now read-only.

[android][native_app] App example of linking to gradle deps native li… #144

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions NativeApp/app/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
cmake_minimum_required(VERSION 3.4.1)
set(TARGET pytorch_nativeapp)
project(${TARGET} CXX)
set(CMAKE_CXX_STANDARD 14)

set(build_DIR ${CMAKE_SOURCE_DIR}/build)

set(pytorch_testapp_cpp_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp)
file(GLOB pytorch_testapp_SOURCES
${pytorch_testapp_cpp_DIR}/pytorch_nativeapp.cpp
)

add_library(${TARGET} SHARED
${pytorch_testapp_SOURCES}
)

file(GLOB PYTORCH_INCLUDE_DIRS "${build_DIR}/pytorch_android*.aar/headers")
file(GLOB PYTORCH_LINK_DIRS "${build_DIR}/pytorch_android*.aar/jni/${ANDROID_ABI}")

target_compile_options(${TARGET} PRIVATE
-fexceptions
)

set(BUILD_SUBDIR ${ANDROID_ABI})

find_library(PYTORCH_LIBRARY pytorch_jni
PATHS ${PYTORCH_LINK_DIRS}
NO_CMAKE_FIND_ROOT_PATH)
find_library(FBJNI_LIBRARY fbjni
PATHS ${PYTORCH_LINK_DIRS}
NO_CMAKE_FIND_ROOT_PATH)

# OpenCV
if(NOT DEFINED ENV{OPENCV_ANDROID_SDK})
message(FATAL_ERROR "Environment var OPENCV_ANDROID_SDK set")
endif()

set(OPENCV_INCLUDE_DIR "$ENV{OPENCV_ANDROID_SDK}/sdk/native/jni/include")

target_include_directories(${TARGET} PRIVATE
"${OPENCV_INCLUDE_DIR}"
${PYTORCH_INCLUDE_DIRS})

set(OPENCV_LIB_DIR "$ENV{OPENCV_ANDROID_SDK}/sdk/native/libs/${ANDROID_ABI}")

find_library(OPENCV_LIBRARY opencv_java4
PATHS ${OPENCV_LIB_DIR}
NO_CMAKE_FIND_ROOT_PATH)

target_link_libraries(${TARGET}
${PYTORCH_LIBRARY}
${FBJNI_LIBRARY}
${OPENCV_LIBRARY}
log)
70 changes: 70 additions & 0 deletions NativeApp/app/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
apply plugin: 'com.android.application'

repositories {
jcenter()
maven {
url "https://oss.sonatype.org/content/repositories/snapshots"
}
}

android {
configurations {
extractForNativeBuild
}
compileSdkVersion 28
buildToolsVersion "29.0.2"
defaultConfig {
applicationId "org.pytorch.nativeapp"
minSdkVersion 21
targetSdkVersion 28
versionCode 1
versionName "1.0"
externalNativeBuild {
cmake {
arguments "-DANDROID_STL=c++_shared"
}
}
}
buildTypes {
release {
minifyEnabled false
}
}
externalNativeBuild {
cmake {
path "CMakeLists.txt"
}
}
sourceSets {
main {
jniLibs.srcDirs = ['src/main/jniLibs']
}
}
}

dependencies {
implementation 'com.android.support:appcompat-v7:28.0.0'

implementation 'org.pytorch:pytorch_android:1.6.0-SNAPSHOT'
extractForNativeBuild 'org.pytorch:pytorch_android:1.6.0-SNAPSHOT'
}

task extractAARForNativeBuild {
doLast {
configurations.extractForNativeBuild.files.each {
def file = it.absoluteFile
copy {
from zipTree(file)
into "$buildDir/$file.name"
include "headers/**"
include "jni/**"
}
}
}
}

tasks.whenTaskAdded { task ->
if (task.name.contains('externalNativeBuild')) {
task.dependsOn(extractAARForNativeBuild)
}
}
19 changes: 19 additions & 0 deletions NativeApp/app/src/main/AndroidManifest.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="org.pytorch.nativeapp">

<application
android:allowBackup="true"
android:label="PyTorchNativeApp"
android:supportsRtl="true"
android:theme="@style/Theme.AppCompat.Light.DarkActionBar">

<activity android:name=".MainActivity">
<intent-filter>
<action android:name="android.intent.action.MAIN" />

<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
</application>
</manifest>
3 changes: 3 additions & 0 deletions NativeApp/app/src/main/assets/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*
*/
!.gitignore
98 changes: 98 additions & 0 deletions NativeApp/app/src/main/cpp/pytorch_nativeapp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#include <android/log.h>
#include <cassert>
#include <cmath>
#include <pthread.h>
#include <unistd.h>
#include <vector>
#define ALOGI(...) \
__android_log_print(ANDROID_LOG_INFO, "PyTorchNativeApp", __VA_ARGS__)
#define ALOGE(...) \
__android_log_print(ANDROID_LOG_ERROR, "PyTorchNativeApp", __VA_ARGS__)

#include "jni.h"

#include <opencv2/opencv.hpp>
#include <torch/script.h>

namespace pytorch_nativeapp {
namespace {
torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) {
cv::Mat image_mat(/*rows=*/image.size(0),
/*cols=*/image.size(1),
/*type=*/CV_32FC1,
/*data=*/image.data_ptr<float>());
cv::Mat warp_mat(/*rows=*/warp.size(0),
/*cols=*/warp.size(1),
/*type=*/CV_32FC1,
/*data=*/warp.data_ptr<float>());

cv::Mat output_mat;
cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{8, 8});

torch::Tensor output =
torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{8, 8});
return output.clone();
}

static auto registry =
torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);

template <typename T> void log(const char *m, T t) {
std::ostringstream os;
os << t << std::endl;
ALOGI("%s %s", m, os.str().c_str());
}

struct JITCallGuard {
torch::autograd::AutoGradMode no_autograd_guard{false};
torch::AutoNonVariableTypeMode non_var_guard{true};
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
};
} // namespace

static void loadAndForwardModel(JNIEnv *env, jclass, jstring jModelPath) {
const char *modelPath = env->GetStringUTFChars(jModelPath, 0);
assert(modelPath);

// To load torchscript model for mobile we need set these guards,
// because mobile build doesn't support features like autograd for smaller
// build size which is placed in `struct JITCallGuard` in this example. It may
// change in future, you can track the latest changes keeping an eye in
// android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp
JITCallGuard guard;
torch::jit::Module module = torch::jit::load(modelPath);
module.eval();
torch::Tensor x = torch::randn({4, 8});
torch::Tensor y = torch::randn({8, 5});
log("x:", x);
log("y:", y);
c10::IValue t_out = module.forward({x, y});
log("result:", t_out);
env->ReleaseStringUTFChars(jModelPath, modelPath);
}
} // namespace pytorch_nativeapp

JNIEXPORT jint JNI_OnLoad(JavaVM *vm, void *) {
JNIEnv *env;
if (vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6) != JNI_OK) {
return JNI_ERR;
}

jclass c = env->FindClass("org/pytorch/nativeapp/NativeClient$NativePeer");
if (c == nullptr) {
return JNI_ERR;
}

static const JNINativeMethod methods[] = {
{"loadAndForwardModel", "(Ljava/lang/String;)V",
(void *)pytorch_nativeapp::loadAndForwardModel},
};
int rc = env->RegisterNatives(c, methods,
sizeof(methods) / sizeof(JNINativeMethod));

if (rc != JNI_OK) {
return rc;
}

return JNI_VERSION_1_6;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package org.pytorch.nativeapp;

import android.content.Context;
import android.os.Bundle;
import android.util.Log;
import androidx.appcompat.app.AppCompatActivity;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

public class MainActivity extends AppCompatActivity {

private static final String TAG = "PyTorchNativeApp";

public static String assetFilePath(Context context, String assetName) {
File file = new File(context.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}

try (InputStream is = context.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
} catch (IOException e) {
Log.e(TAG, "Error process asset " + assetName + " to file path");
}
return null;
}

@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
final String modelFileAbsoluteFilePath =
new File(assetFilePath(this, "compute.pt")).getAbsolutePath();
NativeClient.loadAndForwardModel(modelFileAbsoluteFilePath);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.pytorch.nativeapp;

public final class NativeClient {

public static void loadAndForwardModel(final String modelPath) {
NativePeer.loadAndForwardModel(modelPath);
}

private static class NativePeer {
static {
System.loadLibrary("pytorch_nativeapp");
}

private static native void loadAndForwardModel(final String modelPath);
}
}
3 changes: 3 additions & 0 deletions NativeApp/app/src/main/jniLibs/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*
*/
!.gitignore
20 changes: 20 additions & 0 deletions NativeApp/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
buildscript {
repositories {
google()
jcenter()
}
dependencies {
classpath 'com.android.tools.build:gradle:3.5.0'
}
}

allprojects {
repositories {
google()
jcenter()
}
}

task clean(type: Delete) {
delete rootProject.buildDir
}
3 changes: 3 additions & 0 deletions NativeApp/gradle.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
android.useAndroidX=true
android.enableJetifier=true

52 changes: 52 additions & 0 deletions NativeApp/make_warp_perspective_pt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import torch
import torch.utils.cpp_extension

print(torch.version.__version__)
op_source = """
#include <opencv2/opencv.hpp>
#include <torch/script.h>

torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) {
cv::Mat image_mat(/*rows=*/image.size(0),
/*cols=*/image.size(1),
/*type=*/CV_32FC1,
/*data=*/image.data_ptr<float>());
cv::Mat warp_mat(/*rows=*/warp.size(0),
/*cols=*/warp.size(1),
/*type=*/CV_32FC1,
/*data=*/warp.data_ptr<float>());

cv::Mat output_mat;
cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{64, 64});

torch::Tensor output =
torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{64, 64});
return output.clone();
}

static auto registry =
torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);
"""

torch.utils.cpp_extension.load_inline(
name="warp_perspective",
cpp_sources=op_source,
extra_ldflags=["-lopencv_core", "-lopencv_imgproc"],
is_python_module=False,
verbose=True,
)

print(torch.ops.my_ops.warp_perspective)


@torch.jit.script
def compute(x, y):
if bool(x[0][0] == 42):
z = 5
else:
z = 10
x = torch.ops.my_ops.warp_perspective(x, torch.eye(3))
return x.matmul(y) + z


compute.save("app/src/main/assets/compute.pt")
1 change: 1 addition & 0 deletions NativeApp/settings.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include ':app'