Skip to content

Commit

Permalink
Print ragged tensors in the same way that PyTorch does.
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Sep 15, 2021
1 parent 47a5cca commit 1dd33d1
Show file tree
Hide file tree
Showing 5 changed files with 638 additions and 167 deletions.
2 changes: 2 additions & 0 deletions k2/csrc/ragged_ops_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,8 @@ std::istream &operator>>(std::istream &is, Ragged<T> &r) {
: (row_splits[cur_level + 1].size() - 1));
is.get(); // consume character 'c'
if (cur_level == 0) break;
} else if (c == ',') {
is.get(); // consume character 'c'
} else {
InputFixer<T> t;
is >> t;
Expand Down
63 changes: 45 additions & 18 deletions k2/python/csrc/torch/v2/any.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,24 @@ void PybindRaggedAny(py::module &m) {
// k2.ragged.Tensor methods
//--------------------------------------------------

any.def(
py::init([](py::list data,
py::object dtype = py::none()) -> std::unique_ptr<RaggedAny> {
return std::make_unique<RaggedAny>(data, dtype);
}),
py::arg("data"), py::arg("dtype") = py::none(), kRaggedAnyInitDataDoc);
any.def(py::init<py::list, py::object, torch::Device>(), py::arg("data"),
py::arg("dtype") = py::none(),
py::arg("device") = torch::Device(torch::kCPU),
kRaggedAnyInitDataDeviceDoc);

any.def(
py::init([](const std::string &s,
py::object dtype = py::none()) -> std::unique_ptr<RaggedAny> {
return std::make_unique<RaggedAny>(s, dtype);
}),
py::arg("s"), py::arg("dtype") = py::none(), kRaggedAnyInitStrDoc);
any.def(py::init<py::list, py::object, const std::string &>(),
py::arg("data"), py::arg("dtype") = py::none(),
py::arg("device") = "cpu", kRaggedAnyInitDataDeviceDoc);

any.def(py::init<const std::string &, py::object, torch::Device>(),
py::arg("s"), py::arg("dtype") = py::none(),
py::arg("device") = torch::Device(torch::kCPU),
kRaggedAnyInitStrDeviceDoc);

// Overload: construct a RaggedAny from a string, with the device given as a
// string name (e.g. "cpu", "cuda:0"). The default must be the string "cpu",
// matching the py::list/string-device overload above; the previous default of
// torch::Device(torch::kCPU) was a copy-paste from the torch::Device overload
// and does not match this overload's `const std::string &` device parameter.
any.def(py::init<const std::string &, py::object, const std::string &>(),
        py::arg("s"), py::arg("dtype") = py::none(),
        py::arg("device") = "cpu", kRaggedAnyInitStrDeviceDoc);

any.def(py::init<const RaggedShape &, torch::Tensor>(), py::arg("shape"),
py::arg("value"), kRaggedInitFromShapeAndTensorDoc);
Expand Down Expand Up @@ -408,21 +413,43 @@ void PybindRaggedAny(py::module &m) {
// _k2.ragged.functions
//--------------------------------------------------

// TODO: change the function name from "create_tensor" to "tensor"
m.def(
"create_ragged_tensor",
[](py::list data, py::object dtype = py::none()) -> RaggedAny {
return RaggedAny(data, dtype);
[](py::list data, py::object dtype = py::none(),
torch::Device device = torch::kCPU) -> RaggedAny {
return RaggedAny(data, dtype, device);
},
py::arg("data"), py::arg("dtype") = py::none(),
py::arg("device") = torch::Device(torch::kCPU),
kCreateRaggedTensorDataDoc);

m.def(
"create_ragged_tensor",
[](const std::string &s, py::object dtype = py::none()) -> RaggedAny {
return RaggedAny(s, dtype);
[](py::list data, py::object dtype = py::none(),
const std::string &device = "cpu") -> RaggedAny {
return RaggedAny(data, dtype, device);
},
py::arg("data"), py::arg("dtype") = py::none(), py::arg("device") = "cpu",
kCreateRaggedTensorDataDoc);

m.def(
"create_ragged_tensor",
[](const std::string &s, py::object dtype = py::none(),
torch::Device device = torch::kCPU) -> RaggedAny {
return RaggedAny(s, dtype, device);
},
py::arg("s"), py::arg("dtype") = py::none(),
py::arg("device") = torch::Device(torch::kCPU),
kCreateRaggedTensorStrDoc);

m.def(
"create_ragged_tensor",
[](const std::string &s, py::object dtype = py::none(),
const std::string &device = "cpu") -> RaggedAny {
return RaggedAny(s, dtype, device);
},
py::arg("s"), py::arg("dtype") = py::none(), kCreateRaggedTensorStrDoc);
py::arg("s"), py::arg("dtype") = py::none(), py::arg("device") = "cpu",
kCreateRaggedTensorStrDoc);

m.def(
"create_ragged_tensor",
Expand Down
Loading

0 comments on commit 1dd33d1

Please sign in to comment.