-
Notifications
You must be signed in to change notification settings - Fork 549
Convert CUDA JIT to use nvrtc instead of nvvm #1836
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8d6f981
225a3e0
81d74af
bd5667a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,68 +20,34 @@ namespace JIT | |
| class BinaryNode : public Node | ||
| { | ||
| private: | ||
| const std::string m_op_str; | ||
| const int m_op; | ||
| const int m_call_type; | ||
| std::string m_op_str; | ||
| int m_op; | ||
|
|
||
| public: | ||
| BinaryNode(const char *out_type_str, const char *name_str, | ||
| const std::string &op_str, | ||
| Node_ptr lhs, Node_ptr rhs, int op, int call_type) | ||
| const char *op_str, | ||
| Node_ptr lhs, Node_ptr rhs, int op) | ||
| : Node(out_type_str, name_str, std::max(lhs->getHeight(), rhs->getHeight()) + 1, {lhs, rhs}), | ||
| m_op_str(op_str), | ||
| m_op(op), | ||
| m_call_type(call_type) | ||
| m_op(op) | ||
| { | ||
| } | ||
|
|
||
| void genKerName(std::stringstream &kerStream, Node_ids ids) | ||
| { | ||
| // Make the hex representation of enum part of the Kernel name | ||
| kerStream << "_" << std::setw(2) << std::setfill('0') << std::hex << m_op; | ||
| kerStream << std::setw(2) << std::setfill('0') << std::hex << ids.child_ids[0]; | ||
| kerStream << std::setw(2) << std::setfill('0') << std::hex << ids.child_ids[1]; | ||
| kerStream << std::setw(2) << std::setfill('0') << std::hex << ids.id << std::dec; | ||
| // Make the dec representation of enum part of the Kernel name | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These aren't templated so it would be better if they were implemented in the cpp file.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll do another pass about reorganizing the JIT nodes at a later point (including trying to use the same code for CUDA and OpenCL JIT). Can we leave this be for now? |
||
| kerStream << "_" << std::setw(3) << std::setfill('0') << std::dec << m_op; | ||
| kerStream << std::setw(3) << std::setfill('0') << std::dec << ids.child_ids[0]; | ||
| kerStream << std::setw(3) << std::setfill('0') << std::dec << ids.child_ids[1]; | ||
| kerStream << std::setw(3) << std::setfill('0') << std::dec << ids.id << std::dec; | ||
| } | ||
|
|
||
| void genFuncs(std::stringstream &kerStream, str_map_t &declStrs, Node_ids ids, bool is_linear) | ||
| void genFuncs(std::stringstream &kerStream, Node_ids ids) | ||
| { | ||
| if (m_call_type == 0) { | ||
| std::stringstream declStream; | ||
| declStream << "declare " << m_type_str << " " << m_op_str | ||
| << "(" << m_children[0]->getTypeStr() << " , " | ||
| << m_children[1]->getTypeStr() << ")\n"; | ||
| declStrs[declStream.str()] = true; | ||
|
|
||
| kerStream << "%val" << ids.id << " = call " | ||
| << m_type_str << " " | ||
| << m_op_str << "(" | ||
| << m_children[0]->getTypeStr() << " " | ||
| << "%val" << ids.child_ids[0] << ", " | ||
| << m_children[1]->getTypeStr() << " " | ||
| << "%val" << ids.child_ids[1] << ")\n"; | ||
|
|
||
| } else { | ||
| if (m_call_type == 1) { | ||
| // arithmetic operations | ||
| kerStream << "%val" << ids.id << " = " | ||
| << m_op_str << " " | ||
| << m_type_str << " " | ||
| << "%val" << ids.child_ids[0] << ", " | ||
| << "%val" << ids.child_ids[1] << "\n"; | ||
| } else { | ||
| // logical operators | ||
| kerStream << "%tmp" << ids.id << " = " | ||
| << m_op_str << " " | ||
| << m_children[0]->getTypeStr() << " " | ||
| << "%val" << ids.child_ids[0] << ", " | ||
| << "%val" << ids.child_ids[1] << "\n"; | ||
|
|
||
| kerStream << "%val" << ids.id << " = " | ||
| << "zext i1 %tmp" << ids.id << " to i8\n"; | ||
|
|
||
| } | ||
| } | ||
| kerStream << m_type_str << " val" << ids.id << " = " | ||
| << m_op_str << "(val" << ids.child_ids[0] | ||
| << ", val" << ids.child_ids[1] << ");" | ||
| << "\n"; | ||
| } | ||
| }; | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should change the name of this guy at some point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not doing it now :(