Skip to content

Commit 0d1a39a

Browse files
committed
clean up imports
1 parent 744dc4a commit 0d1a39a

File tree

1 file changed

+15
-29
lines changed
  • experimental/gradually-typed/src/Torch/GraduallyTyped/NN/Transformer

1 file changed

+15
-29
lines changed

experimental/gradually-typed/src/Torch/GraduallyTyped/NN/Transformer/Generation.hs

Lines changed: 15 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,53 +3,39 @@
33
{-# LANGUAGE GADTs #-}
44
{-# LANGUAGE PartialTypeSignatures #-}
55
{-# LANGUAGE PatternSynonyms #-}
6+
{-# LANGUAGE RankNTypes #-}
67
{-# LANGUAGE RecordWildCards #-}
78
{-# LANGUAGE ScopedTypeVariables #-}
89
{-# LANGUAGE TypeApplications #-}
910
{-# LANGUAGE TypeFamilies #-}
1011
{-# LANGUAGE TypeOperators #-}
11-
{-# LANGUAGE RankNTypes #-}
1212

1313
module Torch.GraduallyTyped.NN.Transformer.Generation where
1414

15-
import Control.Monad.State (MonadState, StateT (..), get, put, evalStateT)
16-
import Control.Lens (Lens, Traversal, Lens', (^.), (%~))
15+
import Control.Lens (Lens)
1716
import Control.Monad.Catch (MonadThrow)
17+
import Control.Monad.State (MonadState (..), get, put)
1818
import Data.Function (fix)
19-
import Data.Singletons.Prelude.List (SList (SNil))
2019
import Foreign.ForeignPtr (ForeignPtr)
21-
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDType (..), SDataType (..))
20+
import Torch.GraduallyTyped.DType (DType (..), DataType (..))
2221
import Torch.GraduallyTyped.Index.Type (Index (NegativeIndex), SIndex (..))
23-
import Torch.GraduallyTyped.NN.Class (HasForward (..), stateDictFromFile, HasStateDict (..))
24-
import Torch.GraduallyTyped.NN.Transformer.GEncoderDecoder (SimplifiedEncoderDecoderTransformerGenerationInput (..), SimplifiedEncoderDecoderTransformerOutput (..), SimplifiedEncoderDecoderTransformerOutput' (..), SimplifiedEncoderDecoderTransformerInput' (..))
22+
import Torch.GraduallyTyped.NN.Transformer.GEncoderDecoder (SimplifiedEncoderDecoderTransformerGenerationInput (..), SimplifiedEncoderDecoderTransformerOutput (..))
2523
import Torch.GraduallyTyped.Prelude (Catch, pattern (:|:))
26-
import Torch.GraduallyTyped.Random (Generator)
24+
import Torch.GraduallyTyped.Prelude.List (SList (SNil))
2725
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..), SGradient (..), SRequiresGradient (..))
2826
import Torch.GraduallyTyped.Shape.Class (BroadcastShapesF)
29-
import Torch.GraduallyTyped.Shape.Type (By (..), SBy (..), SSelectDim (..), SelectDim (..), Shape (..), SShape (..), pattern SNoName, pattern (:&:), SSize (..))
27+
import Torch.GraduallyTyped.Shape.Type (By (..), SBy (..), SSelectDim (..), SelectDim (..), Shape (..))
3028
import Torch.GraduallyTyped.Tensor.Indexing (IndexDims, IndexType (..), Indices (..), SIndexType (..), SIndices (..), (!))
3129
import Torch.GraduallyTyped.Tensor.IndexingSlicingJoining (CatHListF, HasCat (..), SqueezeDimF, UnsqueezeF, sSqueezeDim, sUnsqueeze)
3230
import Torch.GraduallyTyped.Tensor.MathOperations.Comparison ((/=.), (==.))
3331
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (mul, mulScalar, sub, subScalar)
34-
import Torch.GraduallyTyped.Tensor.MathOperations.Reduction (ArgmaxF, all, argmax, sAllDim, maxAll, MaxAllCheckF)
35-
import Torch.GraduallyTyped.Tensor.Type (TensorSpec (..), SGetDataType (..), SGetDevice (..), SGetLayout (..), Tensor, TensorLike (..), sSetDataType, SGetShape (..), sCheckedShape, SGetDim)
36-
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerHead (..), mkTransformerInput)
32+
import Torch.GraduallyTyped.Tensor.MathOperations.Reduction (ArgmaxF, MaxAllCheckF, argmax, maxAll)
33+
import Torch.GraduallyTyped.Tensor.Type (SGetDataType (..), SGetDevice (..), SGetDim, SGetLayout (..), SGetShape (..), Tensor, TensorLike (..), sCheckedShape, sSetDataType)
3734
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
38-
import Torch.GraduallyTyped.Device (SDevice (..), SDeviceType (..))
39-
import Torch.GraduallyTyped.NN.Transformer.BART.Common (bartPadTokenId, bartEOSTokenId)
40-
import Torch.GraduallyTyped.NN.Transformer.BART.Base (bartBaseSpec)
41-
import Torch.GraduallyTyped.Tensor.Creation (sOnes)
42-
import Torch.GraduallyTyped.Layout (SLayout (..), SLayoutType (..))
43-
import Torch.GraduallyTyped.Random (Generator, sMkGenerator)
44-
import Torch.GraduallyTyped.NN.Type (SHasDropout (..))
4535
import Torch.HList (HList (HNil), pattern (:.))
46-
import qualified Torch.Internal.Class as ATen (Castable)
36+
import qualified Torch.Internal.Class as ATen
4737
import qualified Torch.Internal.Type as ATen
4838
import Prelude hiding (all)
49-
import Control.Monad ((<=<), (>=>))
50-
import Control.Monad.State (MonadState (..))
51-
import Torch.GraduallyTyped.Layout (Layout(Layout), LayoutType (Dense))
52-
import qualified Tokenizers
5339

5440
decode ::
5541
Monad m =>
@@ -59,10 +45,10 @@ decode ::
5945
m (x, s)
6046
decode f x s = do
6147
flip fix (x, s) $ \loop (x', s') -> do
62-
r <- f x' s'
63-
case r of
64-
Nothing -> pure (x', s')
65-
Just (x'', s'') -> loop (x'', s'')
48+
r <- f x' s'
49+
case r of
50+
Nothing -> pure (x', s')
51+
Just (x'', s'') -> loop (x'', s'')
6652

6753
sedtOutputToInput ::
6854
Monad m =>
@@ -117,7 +103,7 @@ greedyNextTokens ::
117103
SGetDevice logitsDevice,
118104
SGetLayout logitsLayout
119105
) =>
120-
Int ->
106+
Int ->
121107
Int ->
122108
Tensor logitsGradient logitsLayout logitsDevice logitsDataType logitsShape ->
123109
m (Tensor ('Gradient 'WithoutGradient) logitsLayout logitsDevice ('DataType 'Int64) ntShape)

0 commit comments

Comments
 (0)