3
3
{-# LANGUAGE GADTs #-}
4
4
{-# LANGUAGE PartialTypeSignatures #-}
5
5
{-# LANGUAGE PatternSynonyms #-}
6
+ {-# LANGUAGE RankNTypes #-}
6
7
{-# LANGUAGE RecordWildCards #-}
7
8
{-# LANGUAGE ScopedTypeVariables #-}
8
9
{-# LANGUAGE TypeApplications #-}
9
10
{-# LANGUAGE TypeFamilies #-}
10
11
{-# LANGUAGE TypeOperators #-}
11
- {-# LANGUAGE RankNTypes #-}
12
12
13
13
module Torch.GraduallyTyped.NN.Transformer.Generation where
14
14
15
- import Control.Monad.State (MonadState , StateT (.. ), get , put , evalStateT )
16
- import Control.Lens (Lens , Traversal , Lens' , (^.) , (%~) )
15
+ import Control.Lens (Lens )
17
16
import Control.Monad.Catch (MonadThrow )
17
+ import Control.Monad.State (MonadState (.. ), get , put )
18
18
import Data.Function (fix )
19
- import Data.Singletons.Prelude.List (SList (SNil ))
20
19
import Foreign.ForeignPtr (ForeignPtr )
21
- import Torch.GraduallyTyped.DType (DType (.. ), DataType (.. ), SDType ( .. ), SDataType ( .. ) )
20
+ import Torch.GraduallyTyped.DType (DType (.. ), DataType (.. ))
22
21
import Torch.GraduallyTyped.Index.Type (Index (NegativeIndex ), SIndex (.. ))
23
- import Torch.GraduallyTyped.NN.Class (HasForward (.. ), stateDictFromFile , HasStateDict (.. ))
24
- import Torch.GraduallyTyped.NN.Transformer.GEncoderDecoder (SimplifiedEncoderDecoderTransformerGenerationInput (.. ), SimplifiedEncoderDecoderTransformerOutput (.. ), SimplifiedEncoderDecoderTransformerOutput' (.. ), SimplifiedEncoderDecoderTransformerInput' (.. ))
22
+ import Torch.GraduallyTyped.NN.Transformer.GEncoderDecoder (SimplifiedEncoderDecoderTransformerGenerationInput (.. ), SimplifiedEncoderDecoderTransformerOutput (.. ))
25
23
import Torch.GraduallyTyped.Prelude (Catch , pattern (:|:) )
26
- import Torch.GraduallyTyped.Random ( Generator )
24
+ import Torch.GraduallyTyped.Prelude.List ( SList ( SNil ) )
27
25
import Torch.GraduallyTyped.RequiresGradient (Gradient (.. ), RequiresGradient (.. ), SGradient (.. ), SRequiresGradient (.. ))
28
26
import Torch.GraduallyTyped.Shape.Class (BroadcastShapesF )
29
- import Torch.GraduallyTyped.Shape.Type (By (.. ), SBy (.. ), SSelectDim (.. ), SelectDim (.. ), Shape (.. ), SShape ( .. ), pattern SNoName , pattern (:&:) , SSize ( .. ) )
27
+ import Torch.GraduallyTyped.Shape.Type (By (.. ), SBy (.. ), SSelectDim (.. ), SelectDim (.. ), Shape (.. ))
30
28
import Torch.GraduallyTyped.Tensor.Indexing (IndexDims , IndexType (.. ), Indices (.. ), SIndexType (.. ), SIndices (.. ), (!) )
31
29
import Torch.GraduallyTyped.Tensor.IndexingSlicingJoining (CatHListF , HasCat (.. ), SqueezeDimF , UnsqueezeF , sSqueezeDim , sUnsqueeze )
32
30
import Torch.GraduallyTyped.Tensor.MathOperations.Comparison ((/=.) , (==.) )
33
31
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (mul , mulScalar , sub , subScalar )
34
- import Torch.GraduallyTyped.Tensor.MathOperations.Reduction (ArgmaxF , all , argmax , sAllDim , maxAll , MaxAllCheckF )
35
- import Torch.GraduallyTyped.Tensor.Type (TensorSpec (.. ), SGetDataType (.. ), SGetDevice (.. ), SGetLayout (.. ), Tensor , TensorLike (.. ), sSetDataType , SGetShape (.. ), sCheckedShape , SGetDim )
36
- import Torch.GraduallyTyped.NN.Transformer.Type (STransformerHead (.. ), mkTransformerInput )
32
+ import Torch.GraduallyTyped.Tensor.MathOperations.Reduction (ArgmaxF , MaxAllCheckF , argmax , maxAll )
33
+ import Torch.GraduallyTyped.Tensor.Type (SGetDataType (.. ), SGetDevice (.. ), SGetDim , SGetLayout (.. ), SGetShape (.. ), Tensor , TensorLike (.. ), sCheckedShape , sSetDataType )
37
34
import Torch.GraduallyTyped.Unify (type (<+> ), type (<|> ))
38
- import Torch.GraduallyTyped.Device (SDevice (.. ), SDeviceType (.. ))
39
- import Torch.GraduallyTyped.NN.Transformer.BART.Common (bartPadTokenId , bartEOSTokenId )
40
- import Torch.GraduallyTyped.NN.Transformer.BART.Base (bartBaseSpec )
41
- import Torch.GraduallyTyped.Tensor.Creation (sOnes )
42
- import Torch.GraduallyTyped.Layout (SLayout (.. ), SLayoutType (.. ))
43
- import Torch.GraduallyTyped.Random (Generator , sMkGenerator )
44
- import Torch.GraduallyTyped.NN.Type (SHasDropout (.. ))
45
35
import Torch.HList (HList (HNil ), pattern (:.) )
46
- import qualified Torch.Internal.Class as ATen ( Castable )
36
+ import qualified Torch.Internal.Class as ATen
47
37
import qualified Torch.Internal.Type as ATen
48
38
import Prelude hiding (all )
49
- import Control.Monad ((<=<) , (>=>) )
50
- import Control.Monad.State (MonadState (.. ))
51
- import Torch.GraduallyTyped.Layout (Layout (Layout ), LayoutType (Dense ))
52
- import qualified Tokenizers
53
39
54
40
decode ::
55
41
Monad m =>
@@ -59,10 +45,10 @@ decode ::
59
45
m (x , s )
60
46
decode f x s = do
61
47
flip fix (x, s) $ \ loop (x', s') -> do
62
- r <- f x' s'
63
- case r of
64
- Nothing -> pure (x', s')
65
- Just (x'', s'') -> loop (x'', s'')
48
+ r <- f x' s'
49
+ case r of
50
+ Nothing -> pure (x', s')
51
+ Just (x'', s'') -> loop (x'', s'')
66
52
67
53
sedtOutputToInput ::
68
54
Monad m =>
@@ -117,7 +103,7 @@ greedyNextTokens ::
117
103
SGetDevice logitsDevice ,
118
104
SGetLayout logitsLayout
119
105
) =>
120
- Int ->
106
+ Int ->
121
107
Int ->
122
108
Tensor logitsGradient logitsLayout logitsDevice logitsDataType logitsShape ->
123
109
m (Tensor ('Gradient 'WithoutGradient) logitsLayout logitsDevice ('DataType 'Int64) ntShape )
0 commit comments