Add NumToTensor #393
Conversation
Force-pushed from c7586ce to 8507ce3.
LGTM, some NFC comments.
@@ -412,6 +412,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
      return visitAtenMatmulOp(matmul, operands);
    } else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
      return visitAtenSoftmaxIntOp(softmaxIntOp, operands);
    } else if (auto nts = dyn_cast<PrimNumToTensorScalarOp>(op)) {
NIT: I think we should follow the naming convention, i.e., name it numToTensorOp rather than nts.
This comment is not addressed yet.
I've created PR #395 with a simple script you can use to test your implementation e2e. Just replace the graph_str:

graph_str = """\
graph(%p0 : Tensor, %p1 : Tensor):
  %0 : float = prim::Constant[value=0.123]()
  %1 : Tensor = prim::NumToTensor(%0)
  return (%1)
"""

You can also test different input types by changing the constant in the graph.
You should be able to look at the source code location in the IR to find the line in the Python code that results in this op being created and create a Python e2e test.
@silvasean, unfortunately there is no source location on this graph being used for the Bert ops: https://gist.github.com/cathyzhyi/8b9e3ec28f66df2c119795b13abcc1f8
Hmmm... ok. In that case, it is more difficult, but I think the effort to discover a Python test case is valuable. I just searched for "NumToTensor" in PyTorch and found that something like
that might need to be guarded with
Ha, I knew this sounded familiar, we actually have it in one of the test cases already :)
I agree. I just thought it would make debugging much easier if people with ops that are difficult to get information on had a simple way to quickly test their implementations locally.
Hopefully we don't have too many ops like this (they are pretty special). Let's see if it becomes a problem.
Force-pushed from 8507ce3 to c971fc3.
Okay, I've addressed the issues. Should I be writing a new test case, or does the one that Sean pointed out suffice?
auto knowledge =
    ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.hasSizes = true;
knowledge.dtype = getDefaultDtypeForTorchScalar(op->getResultTypes()[0]);
You should use the scalar type op.a().getType() here.
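For intuition, the result dtype follows the scalar operand, which is why the analysis should look at op.a()'s type rather than the result type. A quick eager-mode sanity check of the operator (the float64 expectation is my assumption based on TorchScript modeling float as double):

import torch

# prim::NumToTensor wraps a scalar in a 0-d tensor; the dtype comes from the
# scalar operand's type.
t_int = torch.ops.prim.NumToTensor(3)
print(t_int.dtype)    # torch.int64

# Assumption: a Python/TorchScript float maps to a double-precision tensor.
t_float = torch.ops.prim.NumToTensor(0.5)
print(t_float.dtype)  # expected: torch.float64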
Please refer to https://github.com/llvm/torch-mlir/blob/main/test/Dialect/Torch/refine-types.mlir to add a unit test for refine-types.
@@ -412,6 +412,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
      return visitAtenMatmulOp(matmul, operands);
    } else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
      return visitAtenSoftmaxIntOp(softmaxIntOp, operands);
    } else if (auto nts = dyn_cast<PrimNumToTensorScalarOp>(op)) {
This comment is not addressed yet.
I think you should still add an e2e test. From what I can tell, the test Sean pointed out does not perform any passes on the MLIR, so your code for
Force-pushed from c3a5135 to 0e775f5.
e2e_testing/torchscript/basic.py (Outdated)

@register_test_case(module_factory=lambda: NumToTensorModule())
def NumToTensorModule_basic(module, tu: TestUtils):
    prim_NumToTensor(1)
nit: can you add a newline here? GitHub is showing a red sign because there is no newline at the end of the file.
test/Dialect/Torch/refine-types.mlir (Outdated)

func @aten_prim_numtotensor(%arg0: !torch.int) -> !torch.tensor {
  %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.int -> !torch.tensor
  return %0 : !torch.tensor
}
nit: also needs a newline here.
@gpetters94 reminder: this is not addressed yet :)
e2e_testing/torchscript/basic.py (Outdated)

    return _to_tensor(1)

@torch.jit.script
def prim_NumToTensor(i: int):
What's the reason for having this extra function?
I wasn't able to properly guard the forward function, so this is a hacky way to get it to run.
I just realized something that applies here and might be useful for the future. This is based on the comment Chilli posted on the Discord (https://discord.com/channels/636084430946959380/742573221882364009/904914843683942431). I was looking into how this is implemented (https://github.com/pytorch/pytorch/blob/master/torch/_ops.py#L32-L51), and it works for any namespace! This means that for any op we cannot find, we can just use the JITOperator name to access it. In this case, you can do torch.ops.prim.NumToTensor(some_number). It also extracts some info about the op that can be accessed with help:

> help(torch.ops.prim.NumToTensor)
Help on built-in function NumToTensor:

NumToTensor(...) method of builtins.PyCapsule instance
    NumToTensor(*args, **kwargs) -> object

    Automatically bound operator 'prim::NumToTensor' with schema(s):
      prim::NumToTensor.Scalar(Scalar a) -> (Tensor)
      prim::NumToTensor.bool(bool a) -> (Tensor)

Edit: Of course,
Actually, if you do use
Yeah, that's definitely nicer. I'll refactor right now.
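For reference, a minimal sketch of what the refactored test might look like when calling torch.ops.prim.NumToTensor directly (the class name mirrors the one in the diff; the e2e-framework decorators from basic.py are omitted here):

import torch

class NumToTensorModule(torch.nn.Module):
    def forward(self):
        # Calling the JIT operator directly removes the need for the extra
        # @torch.jit.script wrapper function used previously.
        return torch.ops.prim.NumToTensor(1)

# Scripting the module shows the prim::NumToTensor node in the graph,
# and the module still runs eagerly.
scripted = torch.jit.script(NumToTensorModule())
print(scripted.graph)
print(scripted())  # tensor(1)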
Force-pushed from 0e775f5 to 44d5bd1.
😮 nice find!!! I didn't know it covered all non-c10 jit operators as well!
Force-pushed from 44d5bd1 to 0c67ccd.
// CHECK: %[[OUT:.*]] = torch_c.from_builtin_tensor %[[C]] : tensor<i64> -> !torch.vtensor<[],si64>
// CHECK: return %[[OUT]] : !torch.vtensor<[],si64>

func @ntt(%arg0: !torch.int) -> !torch.vtensor<[],si64> {
Sorry that I missed this in the previous review. It's better to name the function torch.prim.NumToTensor.Scalar; ntt is not a common acronym.
// -----

// CHECK: func @ntt(%[[IN:.*]]: !torch.int) -> !torch.vtensor<[],si64> {
// CHECK:   %[[A:.*]] = torch_c.to_i64 %[[IN]]
It's better to replace A, B, and C with more meaningful names like CST, INIT, FILL.
test/Dialect/Torch/refine-types.mlir (Outdated)

// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[NTT]] : !torch.tensor<[],si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor

func @aten_prim_numtotensor(%arg0: !torch.int) -> !torch.tensor {
The convention here is to name the function after the op, so @torch.prim.NumToTensor.Scalar. (Note: the matmul test before this one breaks the convention.)
@gpetters94 Could you address Ramiro's comment here? It would be great if you could fix the above tests as well. The alignment of the CHECK directives is also off starting from aten_matmul_broadcast_matrix on line 982. Could you help fix that as well?
Sure. What do you mean by CHECK being misaligned? It looks fine to me.
I meant the spacing between the CHECK tag and the text to check against.
How does the formatting look in the most recent commit?
Looks better! Thanks!
Force-pushed from 4dc8879 to e7dc327.
Force-pushed from e7dc327 to 5b0d719.
This looks to be working, but this is the first time I've had to write a test case in MLIR.