Skip to content

Commit

Permalink
Working prototype
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Nett <[email protected]>
  • Loading branch information
rnett committed Jun 27, 2021
1 parent 8f8aadc commit c098372
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 9 deletions.
11 changes: 11 additions & 0 deletions tensorflow-core/tensorflow-core-api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@
<artifactId>ndarray</artifactId>
<version>${ndarray.version}</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.7.31</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
Expand All @@ -73,6 +78,12 @@
<artifactId>jmh-generator-annprocess</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-jdk14</artifactId>
<version>1.7.31</version>
<scope>test</scope>
</dependency>
</dependencies>

<profiles>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,10 +376,10 @@ public final class Ops {

public final SignalOps signal;

public final QuantizationOps quantization;

public final TrainOps train;

public final QuantizationOps quantization;

private final Scope scope;

private Ops(Scope scope) {
Expand All @@ -402,8 +402,8 @@ private Ops(Scope scope) {
math = new MathOps(this);
audio = new AudioOps(this);
signal = new SignalOps(this);
quantization = new QuantizationOps(this);
train = new TrainOps(this);
quantization = new QuantizationOps(this);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ public class TFLogEntry extends Pointer {
public native @StdString @Cast({"char*", "std::string&&"}) BytePointer FName();
public native int Line();
public native @StdString @Cast({"char*", "std::string&&"}) BytePointer ToString();
public native @StdString @Cast({"", "", "std::string"}) BytePointer text_message();

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,20 @@
@Namespace("tensorflow") @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
public class TFLogSink extends Pointer {
static { Loader.load(); }
/** Default native constructor. */
public TFLogSink() { super((Pointer)null); allocate(); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public TFLogSink(long size) { super((Pointer)null); allocateArray(size); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public TFLogSink(Pointer p) { super(p); }
private native void allocate();
private native void allocateArray(long size);
@Override public TFLogSink position(long position) {
return (TFLogSink)super.position(position);
}
@Override public TFLogSink getPointer(long i) {
return new TFLogSink(this).position(position + i);
}


// `Send` is called synchronously during the log statement. The logging
Expand All @@ -24,7 +36,7 @@ public class TFLogSink extends Pointer {
// `e` is guaranteed to remain valid until the subsequent call to
// `WaitTillSent` completes, so implementations may store a pointer to or
// copy of `e` (e.g. in a thread local variable) for use in `WaitTillSent`.
public native void Send(@Const @ByRef TFLogEntry entry);
@Virtual(true) public native void Send(@Const @ByRef TFLogEntry entry);

// `WaitTillSent` blocks the calling thread (the thread that generated a log
// message) until the sink has finished processing the log message.
Expand All @@ -34,5 +46,5 @@ public class TFLogSink extends Pointer {
// The default implementation returns immediately. Like `Send`,
// implementations should be careful not to call `LOG` or `CHECK or take any
// locks that might be held by the `LOG` caller, to avoid deadlock.
public native void WaitTillSent();
@Virtual public native void WaitTillSent();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================
*/
package org.tensorflow;

import org.bytedeco.javacpp.Pointer;
import org.slf4j.ILoggerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.internal.c_api.TFLogEntry;
import org.tensorflow.internal.c_api.TFLogSink;
import org.tensorflow.internal.c_api.global.tensorflow;

public class NativeLogSink extends TFLogSink {
private static final ILoggerFactory factory = LoggerFactory.getILoggerFactory();
private static final Logger logger = LoggerFactory.getLogger(NativeLogSink.class);
NativeLogSink() {
super();
}

@Override
public void Send(TFLogEntry entry) {
//TODO make work, blocked by https://github.com/tensorflow/tensorflow/issues/44995#issuecomment-869091090
System.out.printf(
"Log message: Severity: %d, Fname: %s, line: %s, string: %s\n", entry.log_severity(), entry.FName().getString(), entry.Line(), entry.ToString().getString());
// Logger logger = factory.getLogger(entry.FName().getString());
// switch (entry.log_severity()){
// case tensorflow.kWarning:
// logger.warn(entry.ToString().getString());
// break;
// case tensorflow.kError:
// case tensorflow.kFatal:
// logger.error(entry.ToString().getString());
// break;
// default:
// logger.info(entry.ToString().getString());
// }
}

@Override
public void WaitTillSent() {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package org.tensorflow;

import static org.tensorflow.internal.c_api.global.tensorflow.TFAddLogSink;
import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteBuffer;
import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteLibraryHandle;
import static org.tensorflow.internal.c_api.global.tensorflow.TF_GetAllOpList;
Expand Down Expand Up @@ -96,6 +97,13 @@ public static OpList loadLibrary(String filename) {
}
}

@SuppressWarnings("FieldCanBeLocal")
private static NativeLogSink sink;
private static void setupLogger(){
sink = new NativeLogSink();
TFAddLogSink(sink);
}

private static TF_Library libraryLoad(String filename) {
try (PointerScope scope = new PointerScope()) {
TF_Status status = TF_Status.newStatus();
Expand Down Expand Up @@ -137,5 +145,6 @@ private TensorFlow() {}
e.printStackTrace();
throw e;
}
setupLogger();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -403,15 +403,27 @@ public void map(InfoMap infoMap) {
.cast()
.valueTypes("long")
.pointerTypes("LongPointer", "long[]"))
.put(new Info("string", "std::string", "tensorflow::string").annotations("@StdString")
.valueTypes("@Cast({\"char*\", \"std::string&&\"}) BytePointer", "@Cast({\"char*\", \"std::string&&\"}) String")
.pointerTypes("@Cast({\"char*\", \"std::string*\"}) BytePointer"))
.put(
new Info("string", "std::string", "tensorflow::string")
.annotations("@StdString")
.valueTypes(
"@Cast({\"char*\", \"std::string&&\"}) BytePointer",
"@Cast({\"char*\", \"std::string&&\"}) String")
.pointerTypes("@Cast({\"char*\", \"std::string*\"}) BytePointer"))
// .put(
// new Info("absl::string_view")
// .annotations("@StdString")
// .valueTypes(
// "@Cast({\"char*\", \"std::string&&\", \"std::string\"}) BytePointer",
// "@Cast({\"char*\", \"std::string&&\", \"std::string\"}) String"))
.put(
new Info("absl::LogSeverity", "LogSeverity", "tensorflow::LogSeverity")
.cast()
.valueTypes("int")
.pointerTypes("IntPointer", "int[]"))
.put(new Info("tensorflow::TFLogEntry").purify())
.put(new Info("tensorflow::TFLogSink").virtualize())
.put(new Info("tensorflow::TFLogEntry::text_message").skip())
.put(
new Info(
"tensorflow::internal::LogEveryNSecState",
Expand Down
6 changes: 6 additions & 0 deletions tensorflow-framework/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@
<artifactId>jmh-generator-annprocess</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-jdk14</artifactId>
<version>1.7.31</version>
<scope>test</scope>
</dependency>
<!-- Include native binaries dependencies only for testing -->
<dependency>
<groupId>org.tensorflow</groupId>
Expand Down

0 comments on commit c098372

Please sign in to comment.