Files
pytorch/test/cpp/api/integration.cpp

404 lines
12 KiB
C++
Raw Normal View History

#include <catch.hpp>
2018-05-30 08:55:34 -07:00
#include <torch/nn/modules/batchnorm.h>
#include <torch/nn/modules/conv.h>
#include <torch/nn/modules/dropout.h>
#include <torch/nn/modules/linear.h>
#include <torch/optim/adam.h>
#include <torch/optim/optimizer.h>
#include <torch/optim/sgd.h>
#include <torch/tensor.h>
#include <torch/tensor_list_view.h>
2018-05-30 08:55:34 -07:00
#include <torch/utils.h>
#include <test/cpp/api/util.h>
using namespace torch::nn;
#include <cmath>
#include <iostream>
#include <random>
class CartPole {
// Translated from openai/gym's cartpole.py
public:
double gravity = 9.8;
double masscart = 1.0;
double masspole = 0.1;
double total_mass = (masspole + masscart);
double length = 0.5; // actually half the pole's length;
double polemass_length = (masspole * length);
double force_mag = 10.0;
double tau = 0.02; // seconds between state updates;
// Angle at which to fail the episode
double theta_threshold_radians = 12 * 2 * M_PI / 360;
double x_threshold = 2.4;
int steps_beyond_done = -1;
torch::Tensor state;
double reward;
bool done;
int step_ = 0;
torch::Tensor getState() {
return state;
}
double getReward() {
return reward;
}
double isDone() {
return done;
}
void reset() {
Create ATen tensors via TensorOptions (#7869) * Created TensorOptions Storing the type in TensorOptions to solve the Variable problem Created convenience creation functions for TensorOptions and added tests Converted zeros to TensorOptions Converted rand to TensorOptions Fix codegen for TensorOptions and multiple arguments Put TensorOptions convenience functions into torch namespace too All factory functions except *_like support TensorOptions Integrated with recent JIT changes Support *_like functions Fix in place modification Some cleanups and fixes Support sparse_coo_tensor Fix bug in Type.cpp Fix .empty calls in C++ API Fix bug in Type.cpp Trying to fix device placement Make AutoGPU CPU compatible Remove some auto_gpu.h uses Fixing some headers Fix some remaining CUDA/AutoGPU issues Fix some AutoGPU uses Fixes to dispatch_tensor_conversion Reset version of new variables to zero Implemented parsing device strings Random fixes to tests Self review cleanups flake8 Undo changes to variable.{h,cpp} because they fail on gcc7.2 Add [cuda] tag to tensor_options_cuda.cpp Move AutoGPU::set_index_from into .cpp file because Windows is stupid and sucks Fix linker error in AutoGPU.cpp Fix bad merge conflict in native_functions.yaml Fixed caffe2/contrib/aten Fix new window functions added to TensorFactories.cpp * Removed torch::TensorOptions Added code to generate wrapper functions for factory methods Add implicit constructor from Backend to TensorOptions Remove Var() from C++ API and use torch:: functions Use torch:: functions more subtly in C++ API Make AutoGPU::set_device more exception safe Check status directly in DynamicCUDAHooksInterface Rename AutoGPU to DeviceGuard Removed set_requires_grad from python_variables.h and warn appropriately in Variable::set_requires_grad remove python_default_init: self.type() Add back original factory functions, but with deprecation warnings Disable DeviceGuard for a couple functions in ATen Remove print statement Fix DeviceGuard 
construction from undefined tensor Fixing CUDA device compiler issues Moved as many methods as possible into header files Dont generate python functions for deprecated factories Remove merge conflict artefact Fix tensor_options_cuda.cpp Fix set_requires_grad not being checked Fix tensor_new.h TEMPORARILY put some methods in .cpp files to see if it solves issues on windows and mac Fix bug in DeviceGuard.h Missing includes TEMPORARILY moving a few more methods into .cpp to see if it fixes windows Fixing linker errors * Fix up SummaryOps to use new factories Undo device agnostic behavior of DeviceGuard Use -1 instead of optional for default device index Also move DeviceGuard methods into header Fixes around device index after optional -> int32_t switch Fix use of DeviceGuard in new_with_tensor_copy Fix tensor_options.cpp * Fix Type::copy( * Remove test_non_float_params from ONNX tests * Set requires_grad=False in ONNX tests that use ints * Put layout/dtype/device on Tensor * Post merge fixes * Change behavior of DeviceGuard to match AutoGPU * Fix C++ API integration tests * Fix flip functions
2018-06-16 00:40:35 -07:00
state = torch::empty({4}).uniform_(-0.05, 0.05);
steps_beyond_done = -1;
step_ = 0;
}
CartPole() {
reset();
}
void step(int action) {
auto x = state[0].toCFloat();
auto x_dot = state[1].toCFloat();
auto theta = state[2].toCFloat();
auto theta_dot = state[3].toCFloat();
auto force = (action == 1) ? force_mag : -force_mag;
auto costheta = std::cos(theta);
auto sintheta = std::sin(theta);
auto temp = (force + polemass_length * theta_dot * theta_dot * sintheta) /
total_mass;
auto thetaacc = (gravity * sintheta - costheta * temp) /
(length * (4.0 / 3.0 - masspole * costheta * costheta / total_mass));
auto xacc = temp - polemass_length * thetaacc * costheta / total_mass;
x = x + tau * x_dot;
x_dot = x_dot + tau * xacc;
theta = theta + tau * theta_dot;
theta_dot = theta_dot + tau * thetaacc;
Create ATen tensors via TensorOptions (#7869) * Created TensorOptions Storing the type in TensorOptions to solve the Variable problem Created convenience creation functions for TensorOptions and added tests Converted zeros to TensorOptions Converted rand to TensorOptions Fix codegen for TensorOptions and multiple arguments Put TensorOptions convenience functions into torch namespace too All factory functions except *_like support TensorOptions Integrated with recent JIT changes Support *_like functions Fix in place modification Some cleanups and fixes Support sparse_coo_tensor Fix bug in Type.cpp Fix .empty calls in C++ API Fix bug in Type.cpp Trying to fix device placement Make AutoGPU CPU compatible Remove some auto_gpu.h uses Fixing some headers Fix some remaining CUDA/AutoGPU issues Fix some AutoGPU uses Fixes to dispatch_tensor_conversion Reset version of new variables to zero Implemented parsing device strings Random fixes to tests Self review cleanups flake8 Undo changes to variable.{h,cpp} because they fail on gcc7.2 Add [cuda] tag to tensor_options_cuda.cpp Move AutoGPU::set_index_from into .cpp file because Windows is stupid and sucks Fix linker error in AutoGPU.cpp Fix bad merge conflict in native_functions.yaml Fixed caffe2/contrib/aten Fix new window functions added to TensorFactories.cpp * Removed torch::TensorOptions Added code to generate wrapper functions for factory methods Add implicit constructor from Backend to TensorOptions Remove Var() from C++ API and use torch:: functions Use torch:: functions more subtly in C++ API Make AutoGPU::set_device more exception safe Check status directly in DynamicCUDAHooksInterface Rename AutoGPU to DeviceGuard Removed set_requires_grad from python_variables.h and warn appropriately in Variable::set_requires_grad remove python_default_init: self.type() Add back original factory functions, but with deprecation warnings Disable DeviceGuard for a couple functions in ATen Remove print statement Fix DeviceGuard 
construction from undefined tensor Fixing CUDA device compiler issues Moved as many methods as possible into header files Dont generate python functions for deprecated factories Remove merge conflict artefact Fix tensor_options_cuda.cpp Fix set_requires_grad not being checked Fix tensor_new.h TEMPORARILY put some methods in .cpp files to see if it solves issues on windows and mac Fix bug in DeviceGuard.h Missing includes TEMPORARILY moving a few more methods into .cpp to see if it fixes windows Fixing linker errors * Fix up SummaryOps to use new factories Undo device agnostic behavior of DeviceGuard Use -1 instead of optional for default device index Also move DeviceGuard methods into header Fixes around device index after optional -> int32_t switch Fix use of DeviceGuard in new_with_tensor_copy Fix tensor_options.cpp * Fix Type::copy( * Remove test_non_float_params from ONNX tests * Set requires_grad=False in ONNX tests that use ints * Put layout/dtype/device on Tensor * Post merge fixes * Change behavior of DeviceGuard to match AutoGPU * Fix C++ API integration tests * Fix flip functions
2018-06-16 00:40:35 -07:00
state.data()[0] = x;
state.data()[1] = x_dot;
state.data()[2] = theta;
state.data()[3] = theta_dot;
done = x < -x_threshold || x > x_threshold ||
theta < -theta_threshold_radians || theta > theta_threshold_radians ||
step_ > 200;
if (!done) {
reward = 1.0;
} else if (steps_beyond_done == -1) {
// Pole just fell!
steps_beyond_done = 0;
reward = 0;
} else {
if (steps_beyond_done == 0) {
2018-05-11 18:56:53 -07:00
AT_ASSERT(false); // Can't do this
}
}
step_++;
}
};
template <typename M, typename F, typename O>
bool test_mnist(
uint32_t batch_size,
uint32_t num_epochs,
bool useGPU,
M&& model,
F&& forward_op,
O&& optimizer) {
std::cout << "Training MNIST for " << num_epochs
<< " epochs, rest your eyes for a bit!\n";
struct MNIST_Reader {
FILE* fp_;
2018-05-11 18:56:53 -07:00
explicit MNIST_Reader(const char* path) {
fp_ = fopen(path, "rbe");
if (!fp_)
throw std::runtime_error("failed to open file");
}
~MNIST_Reader() {
if (fp_)
fclose(fp_);
}
2018-05-11 18:56:53 -07:00
uint32_t read_int() {
uint8_t buf[4];
2018-05-11 18:56:53 -07:00
if (fread(buf, sizeof(buf), 1, fp_) != 1) {
throw std::runtime_error("failed to read an integer");
2018-05-11 18:56:53 -07:00
}
return buf[0] << 24u | buf[1] << 16u | buf[2] << 8u | buf[3];
}
uint8_t read_byte() {
uint8_t i;
2018-05-11 18:56:53 -07:00
if (fread(&i, sizeof(i), 1, fp_) != 1) {
throw std::runtime_error("failed to read an byte");
2018-05-11 18:56:53 -07:00
}
return i;
}
};
auto readData = [&](std::string fn) {
MNIST_Reader rd(fn.c_str());
/* int image_magic = */ rd.read_int();
int image_count = rd.read_int();
int image_rows = rd.read_int();
int image_cols = rd.read_int();
Create ATen tensors via TensorOptions (#7869) * Created TensorOptions Storing the type in TensorOptions to solve the Variable problem Created convenience creation functions for TensorOptions and added tests Converted zeros to TensorOptions Converted rand to TensorOptions Fix codegen for TensorOptions and multiple arguments Put TensorOptions convenience functions into torch namespace too All factory functions except *_like support TensorOptions Integrated with recent JIT changes Support *_like functions Fix in place modification Some cleanups and fixes Support sparse_coo_tensor Fix bug in Type.cpp Fix .empty calls in C++ API Fix bug in Type.cpp Trying to fix device placement Make AutoGPU CPU compatible Remove some auto_gpu.h uses Fixing some headers Fix some remaining CUDA/AutoGPU issues Fix some AutoGPU uses Fixes to dispatch_tensor_conversion Reset version of new variables to zero Implemented parsing device strings Random fixes to tests Self review cleanups flake8 Undo changes to variable.{h,cpp} because they fail on gcc7.2 Add [cuda] tag to tensor_options_cuda.cpp Move AutoGPU::set_index_from into .cpp file because Windows is stupid and sucks Fix linker error in AutoGPU.cpp Fix bad merge conflict in native_functions.yaml Fixed caffe2/contrib/aten Fix new window functions added to TensorFactories.cpp * Removed torch::TensorOptions Added code to generate wrapper functions for factory methods Add implicit constructor from Backend to TensorOptions Remove Var() from C++ API and use torch:: functions Use torch:: functions more subtly in C++ API Make AutoGPU::set_device more exception safe Check status directly in DynamicCUDAHooksInterface Rename AutoGPU to DeviceGuard Removed set_requires_grad from python_variables.h and warn appropriately in Variable::set_requires_grad remove python_default_init: self.type() Add back original factory functions, but with deprecation warnings Disable DeviceGuard for a couple functions in ATen Remove print statement Fix DeviceGuard 
construction from undefined tensor Fixing CUDA device compiler issues Moved as many methods as possible into header files Dont generate python functions for deprecated factories Remove merge conflict artefact Fix tensor_options_cuda.cpp Fix set_requires_grad not being checked Fix tensor_new.h TEMPORARILY put some methods in .cpp files to see if it solves issues on windows and mac Fix bug in DeviceGuard.h Missing includes TEMPORARILY moving a few more methods into .cpp to see if it fixes windows Fixing linker errors * Fix up SummaryOps to use new factories Undo device agnostic behavior of DeviceGuard Use -1 instead of optional for default device index Also move DeviceGuard methods into header Fixes around device index after optional -> int32_t switch Fix use of DeviceGuard in new_with_tensor_copy Fix tensor_options.cpp * Fix Type::copy( * Remove test_non_float_params from ONNX tests * Set requires_grad=False in ONNX tests that use ints * Put layout/dtype/device on Tensor * Post merge fixes * Change behavior of DeviceGuard to match AutoGPU * Fix C++ API integration tests * Fix flip functions
2018-06-16 00:40:35 -07:00
auto data = torch::empty({image_count, 1, image_rows, image_cols});
auto a_data = data.accessor<float, 4>();
for (int c = 0; c < image_count; c++) {
for (int i = 0; i < image_rows; i++) {
for (int j = 0; j < image_cols; j++) {
a_data[c][0][i][j] = float(rd.read_byte()) / 255;
}
}
}
return data.toBackend(useGPU ? torch::kCUDA : torch::kCPU);
};
auto readLabels = [&](std::string fn) {
MNIST_Reader rd(fn.c_str());
/* int label_magic = */ rd.read_int();
int label_count = rd.read_int();
auto data = torch::empty({label_count}, torch::kInt64);
auto a_data = data.accessor<int64_t, 1>();
for (int i = 0; i < label_count; ++i) {
a_data[i] = static_cast<int64_t>(rd.read_byte());
}
return data.toBackend(useGPU ? torch::kCUDA : torch::kCPU);
};
auto trdata = readData("test/cpp/api/mnist/train-images-idx3-ubyte");
auto trlabel = readLabels("test/cpp/api/mnist/train-labels-idx1-ubyte");
auto tedata = readData("test/cpp/api/mnist/t10k-images-idx3-ubyte");
auto telabel = readLabels("test/cpp/api/mnist/t10k-labels-idx1-ubyte");
if (useGPU) {
model->cuda();
}
std::random_device device;
std::mt19937 generator(device());
for (auto epoch = 0U; epoch < num_epochs; epoch++) {
auto shuffled_inds = std::vector<int>(trdata.size(0));
for (int i = 0; i < trdata.size(0); i++) {
shuffled_inds[i] = i;
}
std::shuffle(shuffled_inds.begin(), shuffled_inds.end(), generator);
const auto backend = useGPU ? torch::kCUDA : torch::kCPU;
Create ATen tensors via TensorOptions (#7869) * Created TensorOptions Storing the type in TensorOptions to solve the Variable problem Created convenience creation functions for TensorOptions and added tests Converted zeros to TensorOptions Converted rand to TensorOptions Fix codegen for TensorOptions and multiple arguments Put TensorOptions convenience functions into torch namespace too All factory functions except *_like support TensorOptions Integrated with recent JIT changes Support *_like functions Fix in place modification Some cleanups and fixes Support sparse_coo_tensor Fix bug in Type.cpp Fix .empty calls in C++ API Fix bug in Type.cpp Trying to fix device placement Make AutoGPU CPU compatible Remove some auto_gpu.h uses Fixing some headers Fix some remaining CUDA/AutoGPU issues Fix some AutoGPU uses Fixes to dispatch_tensor_conversion Reset version of new variables to zero Implemented parsing device strings Random fixes to tests Self review cleanups flake8 Undo changes to variable.{h,cpp} because they fail on gcc7.2 Add [cuda] tag to tensor_options_cuda.cpp Move AutoGPU::set_index_from into .cpp file because Windows is stupid and sucks Fix linker error in AutoGPU.cpp Fix bad merge conflict in native_functions.yaml Fixed caffe2/contrib/aten Fix new window functions added to TensorFactories.cpp * Removed torch::TensorOptions Added code to generate wrapper functions for factory methods Add implicit constructor from Backend to TensorOptions Remove Var() from C++ API and use torch:: functions Use torch:: functions more subtly in C++ API Make AutoGPU::set_device more exception safe Check status directly in DynamicCUDAHooksInterface Rename AutoGPU to DeviceGuard Removed set_requires_grad from python_variables.h and warn appropriately in Variable::set_requires_grad remove python_default_init: self.type() Add back original factory functions, but with deprecation warnings Disable DeviceGuard for a couple functions in ATen Remove print statement Fix DeviceGuard 
construction from undefined tensor Fixing CUDA device compiler issues Moved as many methods as possible into header files Dont generate python functions for deprecated factories Remove merge conflict artefact Fix tensor_options_cuda.cpp Fix set_requires_grad not being checked Fix tensor_new.h TEMPORARILY put some methods in .cpp files to see if it solves issues on windows and mac Fix bug in DeviceGuard.h Missing includes TEMPORARILY moving a few more methods into .cpp to see if it fixes windows Fixing linker errors * Fix up SummaryOps to use new factories Undo device agnostic behavior of DeviceGuard Use -1 instead of optional for default device index Also move DeviceGuard methods into header Fixes around device index after optional -> int32_t switch Fix use of DeviceGuard in new_with_tensor_copy Fix tensor_options.cpp * Fix Type::copy( * Remove test_non_float_params from ONNX tests * Set requires_grad=False in ONNX tests that use ints * Put layout/dtype/device on Tensor * Post merge fixes * Change behavior of DeviceGuard to match AutoGPU * Fix C++ API integration tests * Fix flip functions
2018-06-16 00:40:35 -07:00
auto inp =
torch::empty({batch_size, 1, trdata.size(2), trdata.size(3)}, backend);
2018-06-19 17:00:28 -07:00
auto lab =
torch::empty({batch_size}, torch::device(backend).dtype(torch::kInt64));
for (auto p = 0U; p < shuffled_inds.size() - batch_size; p++) {
inp[p % batch_size] = trdata[shuffled_inds[p]];
lab[p % batch_size] = trlabel[shuffled_inds[p]];
if (p % batch_size != batch_size - 1)
continue;
Create ATen tensors via TensorOptions (#7869) * Created TensorOptions Storing the type in TensorOptions to solve the Variable problem Created convenience creation functions for TensorOptions and added tests Converted zeros to TensorOptions Converted rand to TensorOptions Fix codegen for TensorOptions and multiple arguments Put TensorOptions convenience functions into torch namespace too All factory functions except *_like support TensorOptions Integrated with recent JIT changes Support *_like functions Fix in place modification Some cleanups and fixes Support sparse_coo_tensor Fix bug in Type.cpp Fix .empty calls in C++ API Fix bug in Type.cpp Trying to fix device placement Make AutoGPU CPU compatible Remove some auto_gpu.h uses Fixing some headers Fix some remaining CUDA/AutoGPU issues Fix some AutoGPU uses Fixes to dispatch_tensor_conversion Reset version of new variables to zero Implemented parsing device strings Random fixes to tests Self review cleanups flake8 Undo changes to variable.{h,cpp} because they fail on gcc7.2 Add [cuda] tag to tensor_options_cuda.cpp Move AutoGPU::set_index_from into .cpp file because Windows is stupid and sucks Fix linker error in AutoGPU.cpp Fix bad merge conflict in native_functions.yaml Fixed caffe2/contrib/aten Fix new window functions added to TensorFactories.cpp * Removed torch::TensorOptions Added code to generate wrapper functions for factory methods Add implicit constructor from Backend to TensorOptions Remove Var() from C++ API and use torch:: functions Use torch:: functions more subtly in C++ API Make AutoGPU::set_device more exception safe Check status directly in DynamicCUDAHooksInterface Rename AutoGPU to DeviceGuard Removed set_requires_grad from python_variables.h and warn appropriately in Variable::set_requires_grad remove python_default_init: self.type() Add back original factory functions, but with deprecation warnings Disable DeviceGuard for a couple functions in ATen Remove print statement Fix DeviceGuard 
construction from undefined tensor Fixing CUDA device compiler issues Moved as many methods as possible into header files Dont generate python functions for deprecated factories Remove merge conflict artefact Fix tensor_options_cuda.cpp Fix set_requires_grad not being checked Fix tensor_new.h TEMPORARILY put some methods in .cpp files to see if it solves issues on windows and mac Fix bug in DeviceGuard.h Missing includes TEMPORARILY moving a few more methods into .cpp to see if it fixes windows Fixing linker errors * Fix up SummaryOps to use new factories Undo device agnostic behavior of DeviceGuard Use -1 instead of optional for default device index Also move DeviceGuard methods into header Fixes around device index after optional -> int32_t switch Fix use of DeviceGuard in new_with_tensor_copy Fix tensor_options.cpp * Fix Type::copy( * Remove test_non_float_params from ONNX tests * Set requires_grad=False in ONNX tests that use ints * Put layout/dtype/device on Tensor * Post merge fixes * Change behavior of DeviceGuard to match AutoGPU * Fix C++ API integration tests * Fix flip functions
2018-06-16 00:40:35 -07:00
inp.set_requires_grad(true);
torch::Tensor x = forward_op(inp);
Create ATen tensors via TensorOptions (#7869) * Created TensorOptions Storing the type in TensorOptions to solve the Variable problem Created convenience creation functions for TensorOptions and added tests Converted zeros to TensorOptions Converted rand to TensorOptions Fix codegen for TensorOptions and multiple arguments Put TensorOptions convenience functions into torch namespace too All factory functions except *_like support TensorOptions Integrated with recent JIT changes Support *_like functions Fix in place modification Some cleanups and fixes Support sparse_coo_tensor Fix bug in Type.cpp Fix .empty calls in C++ API Fix bug in Type.cpp Trying to fix device placement Make AutoGPU CPU compatible Remove some auto_gpu.h uses Fixing some headers Fix some remaining CUDA/AutoGPU issues Fix some AutoGPU uses Fixes to dispatch_tensor_conversion Reset version of new variables to zero Implemented parsing device strings Random fixes to tests Self review cleanups flake8 Undo changes to variable.{h,cpp} because they fail on gcc7.2 Add [cuda] tag to tensor_options_cuda.cpp Move AutoGPU::set_index_from into .cpp file because Windows is stupid and sucks Fix linker error in AutoGPU.cpp Fix bad merge conflict in native_functions.yaml Fixed caffe2/contrib/aten Fix new window functions added to TensorFactories.cpp * Removed torch::TensorOptions Added code to generate wrapper functions for factory methods Add implicit constructor from Backend to TensorOptions Remove Var() from C++ API and use torch:: functions Use torch:: functions more subtly in C++ API Make AutoGPU::set_device more exception safe Check status directly in DynamicCUDAHooksInterface Rename AutoGPU to DeviceGuard Removed set_requires_grad from python_variables.h and warn appropriately in Variable::set_requires_grad remove python_default_init: self.type() Add back original factory functions, but with deprecation warnings Disable DeviceGuard for a couple functions in ATen Remove print statement Fix DeviceGuard 
construction from undefined tensor Fixing CUDA device compiler issues Moved as many methods as possible into header files Dont generate python functions for deprecated factories Remove merge conflict artefact Fix tensor_options_cuda.cpp Fix set_requires_grad not being checked Fix tensor_new.h TEMPORARILY put some methods in .cpp files to see if it solves issues on windows and mac Fix bug in DeviceGuard.h Missing includes TEMPORARILY moving a few more methods into .cpp to see if it fixes windows Fixing linker errors * Fix up SummaryOps to use new factories Undo device agnostic behavior of DeviceGuard Use -1 instead of optional for default device index Also move DeviceGuard methods into header Fixes around device index after optional -> int32_t switch Fix use of DeviceGuard in new_with_tensor_copy Fix tensor_options.cpp * Fix Type::copy( * Remove test_non_float_params from ONNX tests * Set requires_grad=False in ONNX tests that use ints * Put layout/dtype/device on Tensor * Post merge fixes * Change behavior of DeviceGuard to match AutoGPU * Fix C++ API integration tests * Fix flip functions
2018-06-16 00:40:35 -07:00
inp.set_requires_grad(false);
torch::Tensor y = lab;
torch::Tensor loss = torch::nll_loss(x, y);
optimizer.zero_grad();
loss.backward();
optimizer.step();
}
}
torch::NoGradGuard guard;
Create ATen tensors via TensorOptions (#7869) * Created TensorOptions Storing the type in TensorOptions to solve the Variable problem Created convenience creation functions for TensorOptions and added tests Converted zeros to TensorOptions Converted rand to TensorOptions Fix codegen for TensorOptions and multiple arguments Put TensorOptions convenience functions into torch namespace too All factory functions except *_like support TensorOptions Integrated with recent JIT changes Support *_like functions Fix in place modification Some cleanups and fixes Support sparse_coo_tensor Fix bug in Type.cpp Fix .empty calls in C++ API Fix bug in Type.cpp Trying to fix device placement Make AutoGPU CPU compatible Remove some auto_gpu.h uses Fixing some headers Fix some remaining CUDA/AutoGPU issues Fix some AutoGPU uses Fixes to dispatch_tensor_conversion Reset version of new variables to zero Implemented parsing device strings Random fixes to tests Self review cleanups flake8 Undo changes to variable.{h,cpp} because they fail on gcc7.2 Add [cuda] tag to tensor_options_cuda.cpp Move AutoGPU::set_index_from into .cpp file because Windows is stupid and sucks Fix linker error in AutoGPU.cpp Fix bad merge conflict in native_functions.yaml Fixed caffe2/contrib/aten Fix new window functions added to TensorFactories.cpp * Removed torch::TensorOptions Added code to generate wrapper functions for factory methods Add implicit constructor from Backend to TensorOptions Remove Var() from C++ API and use torch:: functions Use torch:: functions more subtly in C++ API Make AutoGPU::set_device more exception safe Check status directly in DynamicCUDAHooksInterface Rename AutoGPU to DeviceGuard Removed set_requires_grad from python_variables.h and warn appropriately in Variable::set_requires_grad remove python_default_init: self.type() Add back original factory functions, but with deprecation warnings Disable DeviceGuard for a couple functions in ATen Remove print statement Fix DeviceGuard 
construction from undefined tensor Fixing CUDA device compiler issues Moved as many methods as possible into header files Dont generate python functions for deprecated factories Remove merge conflict artefact Fix tensor_options_cuda.cpp Fix set_requires_grad not being checked Fix tensor_new.h TEMPORARILY put some methods in .cpp files to see if it solves issues on windows and mac Fix bug in DeviceGuard.h Missing includes TEMPORARILY moving a few more methods into .cpp to see if it fixes windows Fixing linker errors * Fix up SummaryOps to use new factories Undo device agnostic behavior of DeviceGuard Use -1 instead of optional for default device index Also move DeviceGuard methods into header Fixes around device index after optional -> int32_t switch Fix use of DeviceGuard in new_with_tensor_copy Fix tensor_options.cpp * Fix Type::copy( * Remove test_non_float_params from ONNX tests * Set requires_grad=False in ONNX tests that use ints * Put layout/dtype/device on Tensor * Post merge fixes * Change behavior of DeviceGuard to match AutoGPU * Fix C++ API integration tests * Fix flip functions
2018-06-16 00:40:35 -07:00
auto result = std::get<1>(forward_op(tedata).max(1));
torch::Tensor correct = (result == telabel).toType(torch::kFloat32);
std::cout << "Num correct: " << correct.data().sum().toCFloat() << " out of "
<< telabel.size(0) << std::endl;
return correct.data().sum().toCFloat() > telabel.size(0) * 0.8;
}
// Trains an actor-critic policy (shared 128-unit hidden layer, softmax
// policy head, scalar value head) on CartPole with Adam. The test passes once
// the exponentially smoothed episode length exceeds 150, and fails if that
// takes 3000 episodes or more.
TEST_CASE("integration/cartpole") {
  std::cerr << "Training episodic policy gradient with a critic for up to 3000"
               " episodes, rest your eyes for a bit!\n";
  auto model = std::make_shared<torch::SimpleContainer>();
  auto linear = model->add(Linear(4, 128), "linear");
  auto policyHead = model->add(Linear(128, 2), "policy");
  // NOTE: registered under the name "action" although it is the value head.
  auto valueHead = model->add(Linear(128, 1), "action");
  auto optimizer = torch::optim::Adam(model->parameters(), 1e-3);

  // Per-episode trajectory buffers, consumed and cleared by finishEpisode().
  std::vector<torch::Tensor> saved_log_probs;
  std::vector<torch::Tensor> saved_values;
  std::vector<float> rewards;

  // Returns (action probabilities, state value) for a single state.
  auto forward = [&](torch::Tensor inp) {
    auto x = linear->forward(inp).clamp_min(0);
    torch::Tensor actions = policyHead->forward(x);
    torch::Tensor value = valueHead->forward(x);
    return std::make_tuple(torch::softmax(actions, -1), value);
  };

  // Samples an action from the policy and records its log-prob and value.
  auto selectAction = [&](torch::Tensor state) {
    // Only work on single state right now, change index to gather for batch
    auto out = forward(state);
    auto probs = torch::Tensor(std::get<0>(out));
    auto value = torch::Tensor(std::get<1>(out));
    auto action = probs.data().multinomial(1)[0].toCInt();
    // Compute the log prob of a multinomial distribution.
    // This should probably be actually implemented in autogradpp...
    auto p = probs / probs.sum(-1, true);
    auto log_prob = p[action].log();
    saved_log_probs.emplace_back(log_prob);
    saved_values.push_back(value);
    return action;
  };

  // Turns the episode's rewards into discounted, normalized returns and takes
  // one optimizer step on the combined policy + value loss.
  auto finishEpisode = [&]() {
    // Discounted returns (gamma = 0.99), accumulated back-to-front in place.
    auto R = 0.;
    for (auto it = rewards.rbegin(); it != rewards.rend(); ++it) {
      R = *it + 0.99 * R;
      *it = R;
    }
    auto r_t =
        torch::from_blob(rewards.data(), {static_cast<int64_t>(rewards.size())});
    r_t = (r_t - r_t.mean()) / (r_t.std() + 1e-5);

    std::vector<torch::Tensor> policy_loss;
    std::vector<torch::Tensor> value_loss;
    for (auto i = 0U; i < saved_log_probs.size(); i++) {
      // Use the normalized return. r_t used to be computed and then ignored
      // in favour of the raw rewards.
      const auto ret = r_t[i].toCFloat();
      auto advantage = ret - saved_values[i].toCFloat();
      policy_loss.push_back(-advantage * saved_log_probs[i]);
      value_loss.push_back(
          torch::smooth_l1_loss(saved_values[i], torch::ones({1}) * ret));
    }
    auto loss = torch::stack(torch::TensorListView(policy_loss)).sum() +
        torch::stack(torch::TensorListView(value_loss)).sum();
    optimizer.zero_grad();
    loss.backward();
    optimizer.step();
    rewards.clear();
    saved_log_probs.clear();
    saved_values.clear();
  };

  auto env = CartPole();
  double running_reward = 10.0;
  for (auto episode = 0;; episode++) {
    env.reset();
    auto state = env.getState();
    int t = 0;
    for (; t < 10000; t++) {
      auto action = selectAction(state);
      env.step(action);
      state = env.getState();
      auto reward = env.getReward();
      auto done = env.isDone();

      rewards.push_back(reward);
      if (done)
        break;
    }

    running_reward = running_reward * 0.99 + t * 0.01;
    finishEpisode();
    /*
    if (episode % 10 == 0) {
      printf("Episode %i\tLast length: %5d\tAverage length: %.2f\n",
              episode, t, running_reward);
    }
    */
    if (running_reward > 150)
      break;
    REQUIRE(episode < 3000);
  }
}
// Trains a small two-conv / two-linear classifier with dropout on MNIST
// (SGD, lr 1e-2, momentum 0.5) and requires test_mnist to report success.
TEST_CASE("integration/mnist", "[cuda]") {
  // Creation order is kept as-is so parameter initialization is unchanged.
  auto model = std::make_shared<torch::SimpleContainer>();
  auto conv1 = model->add(Conv2d(1, 10, 5), "conv1");
  auto conv2 = model->add(Conv2d(10, 20, 5), "conv2");
  // The dropout layers are used directly and not registered in `model`.
  auto drop = Dropout(0.3);
  auto drop2d = Dropout2d(0.3);
  auto linear1 = model->add(Linear(320, 50), "linear1");
  auto linear2 = model->add(Linear(50, 10), "linear2");

  // 2x2 max-pool, then ReLU via clamp_min(0).
  auto pool_relu = [](torch::Tensor t) -> torch::Tensor {
    return std::get<0>(torch::max_pool2d(t, {2, 2})).clamp_min(0);
  };

  auto forward = [&](torch::Tensor input) -> torch::Tensor {
    torch::Tensor features = pool_relu(conv1->forward(input));
    features = pool_relu(drop2d->forward(conv2->forward(features)));
    torch::Tensor hidden =
        linear1->forward(features.view({-1, 320})).clamp_min(0);
    torch::Tensor logits = linear2->forward(drop->forward(hidden));
    return torch::log_softmax(logits, 1);
  };

  auto optimizer = torch::optim::SGD(
      model->parameters(), torch::optim::SGDOptions(1e-2).momentum(0.5));

  const uint32_t kBatchSize = 32;
  const uint32_t kNumEpochs = 3;
  const bool kUseGPU = true;
  REQUIRE(test_mnist(kBatchSize, kNumEpochs, kUseGPU, model, forward, optimizer));
}
// Same classifier as integration/mnist, but with stateful batch normalization
// in place of dropout (SGD, lr 1e-2, momentum 0.5).
TEST_CASE("integration/mnist/batchnorm", "[cuda]") {
  // Creation order is kept as-is so parameter initialization is unchanged.
  auto model = std::make_shared<torch::SimpleContainer>();
  auto conv1 = model->add(Conv2d(1, 10, 5), "conv1");
  auto batchnorm2d =
      model->add(BatchNorm(BatchNormOptions(10).stateful(true)), "batchnorm2d");
  auto conv2 = model->add(Conv2d(10, 20, 5), "conv2");
  auto linear1 = model->add(Linear(320, 50), "linear1");
  auto batchnorm1 =
      model->add(BatchNorm(BatchNormOptions(50).stateful(true)), "batchnorm1");
  auto linear2 = model->add(Linear(50, 10), "linear2");

  // 2x2 max-pool, then ReLU via clamp_min(0).
  auto pool_relu = [](torch::Tensor t) -> torch::Tensor {
    return std::get<0>(torch::max_pool2d(t, {2, 2})).clamp_min(0);
  };

  auto forward = [&](torch::Tensor input) -> torch::Tensor {
    torch::Tensor features =
        batchnorm2d->forward(pool_relu(conv1->forward(input)));
    features = pool_relu(conv2->forward(features));
    torch::Tensor hidden = batchnorm1->forward(
        linear1->forward(features.view({-1, 320})).clamp_min(0));
    return torch::log_softmax(linear2->forward(hidden), 1);
  };

  auto optimizer = torch::optim::SGD(
      model->parameters(), torch::optim::SGDOptions(1e-2).momentum(0.5));

  const uint32_t kBatchSize = 32;
  const uint32_t kNumEpochs = 3;
  const bool kUseGPU = true;
  REQUIRE(test_mnist(kBatchSize, kNumEpochs, kUseGPU, model, forward, optimizer));
}