XLA Structure


HloModule is the top-level unit in the HLO IR. It corresponds to a whole program. Running a module, from beginning to end, is the only way to run an XLA program.

A module contains one entry computation(HloComputation) which is like main() in a C program. The result of running the module is the result of running this computation.

A module contains some number of nested computations. Each nested computation is attached to an HloInstruction within some other computation. The meaning of the nested computation depends on the instruction it’s attached to.


  • HloModule(const string& name, HloModuleConfig config);
    Constructor without a versioned computation handle. This constructor should only be used for HloModules used outside of the XLA service (eg tests). The versioned handle is used by the service in the compilation cache. A default configuration is created for this module.

  • ```C++ class HloModule { public: // Adds an entry computation to the module. A module can only have one entry // computation. Returns a pointer to the newly added computation. HloComputation* AddEntryComputation( std::unique_ptr computation);

// Same as the AddEntryComputation function above but the module’s // entry_computation_layout is updated to match the layout of the new entry // computation. HloComputation* AddEntryComputationWithLayouts( std::unique_ptr computation);

// Replaces the current entry computation with another computation. // The new entry computation must be a computation that is already in the // module. void ReplaceEntryComputation(HloComputation* entry_computation);

// Adds an embedded computation to the module. HloComputation* AddEmbeddedComputation( std::unique_ptr computation);

// Removes an embedded computation. Status RemoveEmbeddedComputation(HloComputation* to_remove);

// Removes unused computations. Status RemoveUnusedComputations();

// Replaces all uses of computations that are keys of ‘replacements’ with // the corresponding values in ‘replacements’. Replaces the entry computation, // if applicable. // // This function iterates over all instructions in the module to find // computations to replace. We could speed it up by keeping track of users of // computations. void ReplaceComputations( const std::unordered_map<HloComputation, HloComputation>& replacements);

const string& name() const { return name_; } void set_name(string name) { name_ = std::move(name); }

// Returns a deep copy of this module including all computations. std::unique_ptr Clone(const string& suffix = "clone") const; std::unique_ptr Clone(const HloModuleConfig& config, const string& suffix = "clone") const;

// Performs a deep clone of the computation, by recursively cloning all // the called computations as well. If the clone context is specified, it // will be populated with the cloned object mappings. HloComputation* DeepCloneComputation(HloComputation* computation, HloCloneContext* context = nullptr);

// Return a pointer to the entry computation of the module. HloComputation* entry_computation() const { CHECK_NE(nullptr, entry_computation_); return entry_computation_; }

bool has_entry_computation() const { return entry_computation_ != nullptr; }

// Returns the root instruction shape of entry computation. // // Precondition: entry_computation_ is not nullptr. const Shape& result_shape() const { CHECK_NE(nullptr, entry_computation_); return entry_computation()->root_instruction()->shape(); }

// Creates the ComputationLayout which describes the current status of the HLO // module entry computation. ComputationLayout compute_computation_layout() const { return ComputationLayout(entry_computation()->ComputeProgramShape(), /ignore_layouts=/false); }

ComputationLayout* mutable_entry_computation_layout() { return config_.mutable_entry_computation_layout(); }

const ComputationLayout& entry_computation_layout() const { return config_.entry_computation_layout(); }

// Generates a hash value of an HLO module. Hash considers // information on opcode, shape, operands, and typically a root instruction. // This function returns the same hash value for equivalent HLO modules, // with respect to HloInstruction::Identical() method. uint64 Hash() const;

// Gets the computations in this module. // // Returns a view of HloComputations, so you can iterate over this in the // natural way: // // for (HloComputation c : module->computations()) { … } // tensorflow::gtl::iterator_range<UnwrappingIterator< std::vector<std::unique_ptr>::const_iterator>> computations() const { return {MakeUnwrappingIterator(computations_.begin()), MakeUnwrappingIterator(computations_.end())}; } tensorflow::gtl::iterator_range<UnwrappingIterator< std::vector<std::unique_ptr>::iterator>> computations() { return {MakeUnwrappingIterator(computations_.begin()), MakeUnwrappingIterator(computations_.end())}; }

// Returns the computation in this module that has the name name. Returns // null if there is no such computation. HloComputation* GetComputationWithName(absl::string_view name);

// Gets the number of computations in this module. int64 computation_count() const { return computations_.size(); }

// Returns the mutable computation for the given index. HloComputation* mutable_computation(int64 idx) { CHECK(idx >= 0 && idx < computations_.size()); return computations_[idx].get(); }

// Gets the number of instructions in this module. int64 instruction_count() const;

// Deallocate removed instructions in each computation. void Cleanup() { for (auto& comp : computations_) { comp->Cleanup(); } }

// Compute and return a post order of all computations in the module. The sort // is defined like so: if computation A has an instruction which calls // computation B, then A will appear after B in the sort. std::vector<HloComputation*> MakeComputationPostOrder() const;

// Same as MakeComputationPostOrder() but sorting the computations by their // contents. The order is longer post order. std::vector<HloComputation*> MakeComputationSorted() const;

// Gets the computations in this module which aren’t for fusion nodes. // // Postcondition: All computations in the returned list have // !IsFusionComputation(). // // Note: Callers can and do rely on the return value here being a snapshot // of the module’s non-fusion computations – that is, it’s OK to add or // remove computations from a module while iterating over // MakeNonfusionComputations(). std::vector<HloComputation*> MakeNonfusionComputations() const;

// Same as MakeNonfusionComputations() but sorting computations by content. std::vector<HloComputation*> MakeNonfusionComputationsSorted() const;

const HloModuleConfig& config() const { return config_; } void set_config(const HloModuleConfig& config) { config_ = config; }

// Return a string representation of the module. // // (We express the default options using an overload rather than a default // param because gdb ignores default params, but does resolve overloads.) string ToString() const { return ToString(HloPrintOptions()); } string ToString(const HloPrintOptions& options) const;

// Convert an HloModule to or from a proto. HloModuleProto ToProto() const; static StatusOr<std::unique_ptr> CreateFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config, bool prohibit_empty_literal = true);

// Creates and returns an HloModuleConfig with an appropriate program shape // for the HLO module in the given proto. static StatusOr CreateModuleConfigFromProto( const HloModuleProto& module, const DebugOptions& debug_options, const ExecutionOptions* execution_options = nullptr);

// Creates and returns an HloModuleConfig with an appropriate program shape // for the HLO module in the given proto. static StatusOr CreateModuleConfigFromShape( const ProgramShape& program_shape, const DebugOptions& debug_options, const ExecutionOptions* execution_options = nullptr);

// Outlines the given expression from the given computation. // instructions_to_outline contains the instructions that form the expression. // // Precondition: instructions in instructions_to_outline are in topological // order (root of outlined instructions last). TODO(jingyue): takes a set of // instructions and topologically sorts them. HloInstruction* OutlineExpressionFromComputation( absl::Span<HloInstruction* const> instructions_to_outline, const string& outlined_computation_name, HloComputation* computation);

// Returns a randomly generated uint64. uint64 RandomNew64() const;

// Returns the NameUniquer for uniquing instruction names in this module. NameUniquer& instruction_name_uniquer() { return instruction_name_uniquer_; }

// Assign a new unique dense id for an instruction int NewUniqueInstructionId() { int result = next_unique_id_; next_unique_id_++; return result; }

// input_output_alias_config indicates the list of aliased buffers that are // expected from the module. HloInputOutputAliasConfig& input_output_alias_config() { return input_output_alias_config_; } const HloInputOutputAliasConfig& input_output_alias_config() const { return input_output_alias_config_; }

// DynamicParameterBinding holds the list of bindings that indicates which // parameter dimensions are dynamic and which parameters represent their // runtime value. DynamicParameterBinding& dynamic_parameter_binding() { return dynamic_parameter_binding_; } const DynamicParameterBinding& dynamic_parameter_binding() const { return dynamic_parameter_binding_; }

// Returns an id that is unique to this module across all modules created over // the lifetime of this process. int unique_id() const { return unique_id_; }

// Sets the schedule of the module to the given schedule. Status set_schedule(HloSchedule schedule);

// Clears the schedule of the module. void clear_schedule() { schedule_.reset(); }

// Returns true if the module has a schedule set. bool has_schedule() const { return schedule_.has_value(); }

// Returns the schedule of the module. CHECK fails if no schedule is set. const HloSchedule& schedule() const { return *schedule_; } HloSchedule& schedule() { return *schedule_; }

HloComputation* AddComputationAndUnifyNamesAndIds( std::unique_ptr computation, bool is_entry) { computation->ClearUniqueIdInternal(); for (auto* instruction : computation->instructions()) { instruction->ClearUniqueIdInternal(); } return AddComputationInternal(std::move(computation), is_entry, /*uniquify_identifiers=*/true, /*preserve_entry_layouts=*/true); }

Status CheckUniqueNamesAndIdsForComputationsAndInstructions() const;

// Checks if this config has a list of entry parameters’ HLO shardings for // SPMD. bool has_spmd_parameters_shardings() const { return spmd_parameters_shardings_.has_value(); }

// Getter and setter for the list of entry parameters’ HLO shardings for SPMD. const std::vector& spmd_parameters_shardings() const { CHECK(spmd_parameters_shardings_.has_value()); return *spmd_parameters_shardings_; } void set_spmd_parameters_shardings( const std::vector& shardings) { spmd_parameters_shardings_ = shardings; }

// Checks if this config has the entry computation output’s HLO sharding for // SPMD. bool has_spmd_output_sharding() const { return spmd_output_sharding_.has_value(); }

// Getter and setter for the entry computation output’s HLO shardings for // SPMD. const HloSharding& spmd_output_sharding() const { CHECK(spmd_output_sharding_.has_value()); return *spmd_output_sharding_; } void set_spmd_output_sharding(const HloSharding& sharding) { spmd_output_sharding_ = sharding; }

// Add a program argument to be prefetched across programs. void AddCrossProgramPrefetch(int64 parameter, const ShapeIndex& index) { cross_program_prefetches_.emplace_back(parameter, index); }

// Get the list of program arguments to be prefetch across programs. const absl::Span<const std::pair<int64, ShapeIndex» CrossProgramPrefetches() const { return cross_program_prefetches_; }

private: HloComputation* AddComputationInternal( std::unique_ptr computation, bool is_entry, bool uniquify_identifiers, bool preserve_entry_layouts);

string name_; HloModuleConfig config_; HloComputation* entry_computation_ = nullptr; std::vector<std::unique_ptr> computations_;

// Random number generator engine to use when generating random numbers per // HloModule compilation. // TODO(b/25995601): Replace with better seed setting or dev/random for // where we don’t need deterministic execution. mutable std::mt1993764 rng{42}; mutable tensorflow::mutex rng_mutex_;

// Unique name generator for computation and instruction names, which are // unique per module. NameUniquer computation_name_uniquer_{/separator=/”.”}; NameUniquer instruction_name_uniquer_{/separator=/”.”}; int next_unique_id_ = 0;

// Used to keep track of the next unique module id that should be assigned. static std::atomic next_unique_module_id_; // A unique id to label modules with. int unique_id_;

// The HloSchedule of the module. The schedule if it exists contains a // sequential order of instructions for each non-fusion computation in the // module. absl::optional schedule_;

// alias_config indicates the alias information of input/output buffers that // are expected from the module. HloInputOutputAliasConfig input_output_alias_config_;

// Bindings for dynamic parameter mapping. DynamicParameterBinding dynamic_parameter_binding_;

// The HLO shardings of the entry computation’s parameters for // SPMD-partitioned programs. absl::optional<std::vector> spmd_parameters_shardings_;

// The HLO sharding of the entry computation’s output (root) for // SPMD-partitioned programs. absl::optional spmd_output_sharding_;

// Arguments to be prefetched across programs. std::vector<std::pair<int64, ShapeIndex» cross_program_prefetches_; };

# XLA Optimization Pass
In XLA, optimization passes are represented in `HloPassPipeline`. For example, the optimization for GPU backend is performed by `GpuCompiler::OptimizeHloModule`. You can see the pipeline below in `tensorflow/compiler/xla/service/gpu/gpu_compiler.cc`.

# CPU Backend Level Optimization
## Convolution canonicalization  
An HLO pass that canonicalizes the dimension numbers of all top-level convolutions in the given module. 

In order to hit the fast path of using Eigen's convolution implementation, a convolution's dimension numbers need to satisfy certain constraints (so called canonical convolutions). 

This pass expands non-canonical convolutions into reshapes and canonical convolutions, so that these non-canonical convolutions can run faster.

## Parallel task assigner  
ParallelTaskAssigner computes target parallel task counts for all HLOs in the module, then assigns parallel task counts to HLOs in the entry computation, or to HLOs in embedded computations invoked by (potentially nested) kWhile or kCall instructions. 

Each HLO which is assigned parallel task counts is outlined into its own embedded computation, which is compiled as a parallel compute function, and which is invoked from a kCall instruction that is lowered in codegen to a runtime parallel fork/join call.

# GPU Backend Level Optimization

## Alias passthrough params  
This pass aliases input and output buffers that are associated with a parameter that is passed through to the module root unmodified.

This pass assumes that parameters and the root use unnested shapes, which is the case for XLA:GPU.

This pass must run prior to copy insertion.

## cuBLAS GEMM pad for tensor cores  
Adds padding to dot operations to make them run faster on GPUs with tensor cores [https:devblogs.nvidia.com/programming-tensor-cores-cuda-9/](https:devblogs.nvidia.com/programming-tensor-cores-cuda-9/).

f16 dots are padded to have input/output shapes with dimensions that are multiples of 8, so that we can use tensor cores.

Don't run this pass on GPUs without tensor cores -- it will make them slower!
This pass depends on xla::DotDecomposer pass, so it should go strictly later.

## cuDNN batchnorm rewriter  
Rewrites BatchNorm HLOs into calls into cudnn where possible.

A call into cudnn for performing a batchnorm op is represented as a `CustomCall` HLO with `custom_call_target` equal to one of

- `kCudnnBatchNormForwardInferenceCallTarget`
- `kCudnnBatchNormForwardTrainingCallTarget`, or
- `kCudnnBatchNormBackwardCallTarget`.

A `CustomCall` created by this pass has the same operands corresponding batchnorm HLO, except the epsilon() and feature_index() properties of the batchnorm HLO are converted into proper operands, added to the end of the `CustomCall`'s operands list.

The inputs/outputs of the cudnn calls for `BatchNormTraining` and `BatchNormGrad` do not correspond exactly to the HLOs.  In particular, the training cudnn call returns `1/sqrt(variance + epsilon)`, while the HLO returns plain variance.  Similarly, the grad cudnn call expects `1/sqrt(variance + epsilon)` as input, whereas the HLO expects plain variance.

This pass adds HLOs in front of / behind the `CustomCall`s to fix up the inputs/outputs as appropriate, and we rely on the `AlgebraicSimplifier` to remove these where possible.

Currently batchnorm ops over F32s are converted into cudnn calls, so long as epsilon is not too small.  This pass leaves other batchnorm ops unmodified.

The GPU backend does not implement a lowering for the batchnorm HLOs -- it expects them to be lowered to cudnn calls via this pass or to HLO soup via `BatchNormRewriter`.

## cuDNN fused conv rewriter  
Rewrite the custom call targeting cudnnConvolutionForward to cudnnConvolutionBiasActivationForward by fusing applicable point-wise operations following forward convolution.  This transform must run after `cudnn_conv_rewriter`. It is straightforward for floating point convolutions: 


max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias))


cudnnConvolutionBiasActivationForward(x, w, bias, alpha1, alpha2, side)

Integer convolution requires additional patterns to match CuDNN semantics:
#1 from

cast(clamp<-128, 127>(conv(int8_x, int8_w)))


cudnnConvolutionForward(int8_x, int8_w)

or #2 from

cast(conv(int8_x, int8_w))


cudnnConvolutionForward(int8_x, int8_w)

or #3 from

cast(clamp<-128, 127>(max(0, alpha1 * cast(conv(int8_x, int8_w)) + alpha2 * cast(int8_side) + broadcast(bias)))


cudnnConvolutionBiasActivationForward(int8_x, int8_w, bias, alpha1, alpha2, int8_side)

or #4 from

max(0, alpha1 * cast(conv(int8_x, int8_w)) + alpha2 * float_side + broadcast(bias))


cudnnConvolutionBiasActivationForward(int8_x, int8_w, bias, alpha1, alpha2, float_side)

Rewrite the custom call targeting `cudnnConvolutionForward` to `cudnnConvolutionBiasActivationForward` by fusing applicable point-wise operations following forward convolution. This transform must run after `cudnn_conv_rewriter`.
It is straightforward for floating point convolutions:

max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias))


cudnnConvolutionBiasActivationForward(x, w, bias, alpha1, alpha2, side)

Integer convolution requires additional patterns to match CuDNN semantics:
#1 from

cast(clamp<-128, 127>(conv(int8_x, int8_w)))


cudnnConvolutionForward(int8_x, int8_w)

or #2 from

cast(conv(int8_x, int8_w))


cudnnConvolutionForward(int8_x, int8_w)

or #3 from

cast(clamp<-128, 127>(max(0, alpha1 * cast(conv(int8_x, int8_w)) + alpha2 * cast(int8_side) + broadcast(bias)))


cudnnConvolutionBiasActivationForward(int8_x, int8_w, bias, alpha1, alpha2, int8_side)

or #4 from

max(0, alpha1 * cast(conv(int8_x, int8_w)) + alpha2 * float_side + broadcast(bias))


cudnnConvolutionBiasActivationForward(int8_x, int8_w, bias, alpha1, alpha2, float_side)

## cuDNN pad for convolutions  
Two zero-paddings for CuDNN thunking are done in this transform: padding for tensor cores and padding for integer convolutions.  This transform also add slice instruction to remove unnecessary output features.

## cuSolver rewriter  
Rewrites Cholesky calls into `CustomCall` HLOs that call into cuSolver.

## Fusion merger  
An HLO pass that attempts to merge fusion instructions to reduce kernel launch overhead and improve data locality.

Fusion instructions are merged into their users if two conditions are met:

1) The flops_to_bytes ratio of the fusion instruction is below the threshold value of 1.0.
2) The result of merging the fusion instruction into its users would not increase bytes transferred.

## GEMM algorithm picker  

## GEMM rewriter  
cuBLAS GEMM in the most general form can run the following operation:

(kMultiply (kDot A B) alpha)
(kMultiply C beta))

where A, B, C are matrixes and `alpha` and `beta` are host constants. The additional requirement is that C has no other users (otherwise, it does not make sense to fuse it inside the custom call).

Both multiplication and addition can be avoided (equivalent to setting `alpha` to one and `beta` to zero).

This pass pattern-matches the most general form of this instruction (we assume transposes are already folded), and rewrites it into a custom call where (A, B, C) are three operands respectively, and `alpha` and `beta` are stored in the backend config.

## GPU conv algorithm picker  
Modifies `CustomCall`s to cudnn convolutions, choosing the best algorithm for each and adding explicit scratch space to the `CustomCall`s.

## GPU conv padding legalization  
An HLO pass that canonicalizes convolution instructions for GPU codegen. It inserts Pad instructions before Convolution instructions with uncanonicalized padding, so that they can be lowered to Cudnn/Miopen convolution.

## GPU conv rewriter  
Rewrites plain convolutions, backwards-filter convolutions, and backwards-input convolutions into `CustomCall` HLOs that call into Cudnn/Miopen. 

For integer convolution, it requires the following pattern:

conv<InputT=int32, ResultT=int32>( convert(int8_x), convert(int8_y))

We transform it to:

custom_call(int8_x, int8_y, target=cudnnForwardConvolution)

Note that this pattern is necessary but not sufficient to map convolutions to CuDNN. More patterns will be matched in `cudnn_fused_conv_rewriter`.

## GPU copy insertion  
Besides the modifications made by the generic `xla::CopyInsertion`, this GPU-specific copy insertion also materializes operands of library calls by inserting kCopy instructions.

## GPU sanitize constant names  
Sanitizes HLO instruction names for the GPU backend. Currently, it only replaces . and - in the HLO constant instruction names with _ to please the LLVM PTX backend.

## Horizontal fusion  
This optimization pass horizontally fuses computations for reducing kernel launch overhead while increasing kernel launch dims on GPU. The initial motivation of this horizontal fusion is due to the observation that the training optimizer phase (e.g., AdamOptimizer and L2Loss, etc.) typically has many small kernels as a result of applying the same formula on many training parameters (or variables in Tensorflow). Fusing these small kernels, hence, provides performance gain.

Theoretically speaking, we may implement a cycle detection algorithm to make sure no cycles are created after fusion. However, cycle detection check is somewhat cumbersome; also, we observe that naive horizontal fusion of arbitrary kernels may not be profitable due to control divergence and possible increase of memory bandwidth pressure due to uncoalesced memory accesses (note that horizontal fusion does not change the amount of memory read+written at all). In practice, a simple yet effective heuristic is used to avoid these issues while addressing the known beneficial cases. That is, we simply search for fusion candidates by looking for instructions whose outputs are all consumed by the same instruction. This catches the cases in the training optimizer phase, as the candidate instructions are typically consumed only by the ROOT tuple of the entry computation.

The following illustrates the mechanism of the horizontal fusion. Before fusion, there are two trivial kernels in the illustrating example. One has only a Mul op, while the other consists of only an Add op. Since they are only consumed by the same (ROOT) tuple instruction, horizontal fusion is triggered.

i0 i1 i2 i3 | | | | v v v v Mul Add | | v v (ROOT) tuple

We horizontally fuse them into the below pattern.

i0 i1 i2 i3 +++ (Slice) Input Fusion | | | | + v v v v + Mul Add + | | + v v + Reshape0 Reshape1 + | | + v v + Concatenate + | | + v v + Slice0 Slice1 +++ | | v v Reshape2 Reshape3 | | v v (ROOT) tuple

Note that this fusion style provides an important advantage that kernels of different shapes can be horizontally fused. The first pair of reshapes (i.e., Reshape0 and Reshape1) reshape the dims to 1 dimension, so that the outputs of the fused kernels can (always) be concatenated. The second pair of reshapes (Reshape2 and Reshape3) restore the original shapes to the output tensors.

No extra copies are introduced by the horizontal fusion. Besides Reshape2 and Reshape3, the other instructions are fused into an input fusion; the output dims of the concatenate will be used as the kernel launch dims. Instruction bitcasts can be used for Reshape2 and Reshape3 as long as the outputs of Mul and Add are row-major.

## Multi output fusion  
Multi-output fusion of sibling and producer-consumer instructions for the GPU backend.

## Reduction degenerate dim remover  
Enforces the invariant that reduction input and output have no degenerate (size 1) dimension. Since these dimensions are physically meaningless, they are removed using bitcasts.

For example,

f[1] out = reduce(f[100, 1, 1] input, dimensions={0, 1})


f[100] tmp1 = f[100] bitcast(f[100, 1, 1], input) f[] tmp2 = reduce(f[100] tmp1, dimensions={0}) f[1] out = f[] bitcast(tmp2)

## Reduction dimension grouper  
Groups adjacent (logically and physically) reduced dimensions in reduction input.

Precondition: ReductionLayoutNormalizer has been run (physical proximity and logical proximity become the same).

For example,

f[] out = reduce(f[10,20,30] input, dimensions={0,1,2})


f[600] tmp = f[600] bitcast(f[10,20,30] input) f[] out = reduce(f[600] tmp, dimensions={0})

## Reduction layout normalizer  
Enforces default (minor-to-major) layout on all reduction inputs.
Note that since reduction output can request a custom layout,
this pass only guarantees standard layout for the input.

For example,

f[20,30]{0,1} out = reduce(f[10,20,30]{2,0,1} input, dimensions={0})


f[20,10,30] tmp = f[20,10,30] bitcast(f[10,20,30]{2,0,1} input) f[20,30]{0,1} out = reduce(f[20,10,30]{2,1,0} tmp, dimensions={1})

## Reduction splitter  
Splits a reduce op into two consecutive reduce ops if
* the reduce dimensions are not contiguous and
* at least one reduce dimension is large (i.e. corresponds to a large input
shape dimension).

Reductions with non-contiguous dimensions are emitted as simple element-wise loops. This is inefficient when reducing large input shape dimensions. Splitting such reductions allows using more efficient reduction emitters.

This pass splits reduce ops into two consecutive reduce ops. Run it to a fixpoint to split reduce ops along multiple large dimensions.

Precondition: ReductionDimensionGrouper has been run and adjacent reduce dimentsions have been grouped. Reduction layouts have been normalized.

## Tree reduction rewriter  
Rewrites reductions in a way they can be implemented without atomics.

Rule application: rewrite a single HLO reduce operation into two.

- Case 1: Row reduction, batched dimension is present, larger than Z-tiling size.


f32[B] out = reduce(f32[A, B, C] input, dimensions={0, 2})


f32[A, B] tmp = reduce(f32[A, B, C] input, dimensions={2}) f32[B] out = reduce(f32[A, B] tmp, dimensions={0})

- Case 2: Row reduction

Let M be the thread tiling multiplied by the warp size.
We go from (assuming C > M):

f32[B] out = reduce(f32[A, B, C] input, dimensions={0, 2})


f32[A, B, P] padded = pad(input) Let P = ceil(C/M) * M. f32[A, B, Q, M] reshaped = bitcast(padded) Let Q = ceil(C/M) f32[B, Q] inner_reduce = reduce(reshaped, dimensions={0, 3}) f32[B] outer_reduce = reduce(inner_reduce, dimensions={1})

- Case 3: Column reduction

Let T be the tiling size for the column reduction.

We go from (assuming B > T):

f32[A, C] out = reduce(f32[A, B, C] input, dimensions={1})


f32[A, P, C] padded = pad(input) Let P = ceil(B/T) * T. f32[A, Q, T, C] reshaped = bitcast(padded) Let Q = ceil(B/T) f32[A, Q, C] inner_reduce = reduce(reshaped, dimensions={2}) f32[A, C] outer_reduce = reduce(inner_reduce, dimensions={1})

## Variadic op splitter  
Splits variadic ops with many operands into pieces such that we don't exceed the parameter space on the GPU. Currently only concatenate ops are split up.

# Frontend Level Optimization

## Algebraic simplifier
A pass which performs algebraic simplifications.

## All gather decomposer
AllGatherDecomposer is a pass which converts unsupported all-gathers into dynamic-update-slices and all-reduces.

## All reduce combiner
Combines small non-dependent AllReduce ops into larger combined AllReduce ops. A typical AllReduce implementation has a minimum latency-induced time for a AllReduce op so a single combined op can be more efficient than many small ones.

## All reduce simplifier
A pass that detects all-reduces whose inputs are already the same across replicas using the replication analysis, then replaces those all-reduces with local computations. E.g., a sum all-reduce on replicated input will be replaced by a multiply with the replica count.

## All reduce cross combiner
When the HLO graph contains a cross-module AllReduce (N separate AllReduce
ops that share the same channel_id for MPMD partitioning, or 1 AllReduce op
for SPMD partitioning), followed by some simple linear operations, followed
by a cross-replica AllReduce (also known as cross-replica sum, or CRS), we
can combine the CMAR and the CRAR, to use an efficient AllReduce
implementation that fully utilizes the interconnect bandwidth.

Such sequences appear in spatially partitioned models (either MPMD or SPMD).
This pass must run right after spatial partitioning, when the code is still
in a single HLO module.

The steps are:
1) Find CMARs followed by simple ops followed by CRARs.
2) Group CMARs by channel_id. They must all be rewritten. For SPMD
   partitioning, there will only be a single CMAR for each channel_id.
3) Prove that the CMAR patterns in each core produce the same result.
4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the
   other operand by the number of spatial partitions.
5) Turn the CRAR into an all-core AllReduce.

The pass also handles the case where multiple CMARs lead to the same CRAR,
and eliminates all CMARs. This graph:
   |  X   CMAR_2   Z  |      \    / CMAR_1     +    \     /
CRAR ``` gets rewritten to: ```
      Z   num_partitions
       \  /
  Y    div
   \   /    X     +
\   /
  |  all-core AR ```

Batch dot simplification

Simplifies batch dot operations.

Normally these would live in the algebraic simplifier, but we want to run this to fixpoint (this pass reaches fixed point in one execution) before we run the DotDecomposer.

Batchnorm expander

A pass which rewrites batch norm operations into more operations. Breaking a big operation into smaller operations helps leverage our generic fusion logic.

bfloat16 conversion folding

A pass which folds F32 <-> BF16 conversions to their operands or users, when it is supported by the backend.

This pass follows the passed-in backend-specific BF16 support rules, but can introduce mixed precision in individual HLOs which breaks the assumption of some other HLO passes. So it should be used at the end of the HLO optimization pipeline followed by a DCE pass. If other passes are needed after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the changed made by this pass.

bfloat16 normalization

A pass which adds F32 <-> BF16 conversions for HLO instructions that do not support BF16 input/output or mixed precision, according to the passed-in backend-specific BF16 support rules.

bfloat16 mixed precision removal

A pass that unconditionally removes the mixed F32/BF16 uses in HLO instructions (excluding convert) by adding F32 <-> BF16 conversions. Unlike BFloat16Normalization, this pass does not use a backend-specific BFloat16Support, and does not change HLOs that have BF16 data if they do not use mixed precision; it removes mixed precision even if the backend supports it. This pass is used to make the HLO module valid for other HLO passes which do not support mixed precision.

bfloat16 propagation

HLO pass which reduces the precision of some HLO instructions to BF16 according to the backend-specific BFloat16Support rule provided by the caller.

This pass can be used to reduce instruction precision without affecting the numerical accuracy of the module, i.e., the final output of the module would be bitwise identical to that without this pass; this is possible if the backend already reduces precision to BF16 on some HLO instructions.

This pass will not modify the signature of a computation, unless it is a fusion computation or its only caller is a while.

!!! WARNING !!! This pass can introduce mixed precision in individual HLOs, which has two issues:

1) It does not guarantee to respect the passed-in BFloat16Support specification in terms of mixed precision, so the backend may not support an HLO that has mixed precision produced by this pass. To address this issue, run BFloat16Normalization with the same BFloat16Support after this pass.

2) In general, mixed precision may break the assumptions of some other HLO passes even if the specific backend supports the individual HLOs. Such assumptions include that there are no HLOs using mixed precision, or that the precision of an HLO’s output is determined by its inputs. It should be used at the end of the HLO optimization pipeline but before BFloat16ConversionFolding. If other passes are needed after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the changes made by this pass.

Call inliner

For every kCall operation in the main computation, we inline the body of the called function, and proceed recursively.

Conditional canonicalizer

Canonicalize output of conditionals, make non-tuple outputs into tuple with single element output. After this pass, all conditional instructions have tuple outputs.

Conditional code motion

HLO pass that moves identical ops in/out of conditional.

  • The definition of identical are the shape of the operands are identical and their properties are identical.
  • Only the identical ops that won’t share operands with other ops will be moved out of conditional.

Conditional simplifier

HLO pass that removes kConditional with a constant predicate, replacing them with their true or false computation as appropriate.

Conditional to select

A pass which transforms conditionals to selects in places where conditionals are legal, but not currently supported by the backends (e.g. inside kMap)

Convolution group converter

A pass which rewrites convolutions with feature_group_count > 1 into convolutions with feature_group_count = 1.

Copy insertion

Copy insertion is a legalization HLO pass which inserts copies (kCopy instructions) to eliminate several kinds of problems in the HLO module.

(1) Entry parameter or a constant live out of the entry computation. Entry computation arguments and constants have different lifetimes than the computation result and cannot share the same allocation. Parameters and constants live out of non-entry computations do not need copies.

(2) Different values which are simultaneously live and which must be held in the same buffer. This can occur in while bodies. Specifically, the while loop state (the arguments to the while instruction) is updated in-place and the update may clobber the value from the previous iteration before the previous value is dead. Computations called from kCall instructions do not need such copies because kCall has no update in-place semantics.

(3) The buffer set of the root instruction of the entry computation must be unambiguous and distinct. That is, InstructionAliasSet::IsAmbiguous and InstructionAliasSet::IsDistinct return true.


A pass which replaces all fusion instructions with the equivalent un-fused instructions.


Pass which strips control dependencies from all instructions in the module. Creates an HloPassPipeline containing multiple HloPasses that can despecialize an optimized HloModule. This is useful to run an HloModule optimized for one specific platform on a different platform (undoing platform specific passes) with matching numerics for comparison.

Current despecialization passes are HloDescheduler, ControlDepRemover, Defuser and BFloat16MixedPrecisionRemoval.

Dot decomposer

DotDecomposer is a pass which converts dots into a canonical form where non-contracting and contracting dimensions are reshaped together and batch dimensions are the most major dimensions.

Dynamic index splitter

Convert R1 index operands to DynamicSlice and DynamicUpdateSlice ops into separate scalars.

Dynamic padder

With bounded shapes, only part of the shape contains effective data and the rest contains padded data, whose value can be anything depending on the source of the data. When a bounded shape is directly consumed by an instruction that collapses dimensions (reduce for example), the padding data would affect result of the instruction.

DynamicPadder uses DynamicDimensionInference to detect bounded shapes in a hlo module, it then inserts certain instructions to reset the padding into an identity value so that in doesn’t affect the result of subsequent instruction. For example, it’d reset the padding to 0 before a bounded shape is consumed by a reduce-sum.

Dynamic_padder removes dynamic shapes from the entry computation, and inserts custom calls (with dynamic shapes), which are lowered by specialized emitters: PadToStatic and SliceToDynamic.

Each instruction can have one of the three modes in supporting dynamic lowering.

Flatten call graph

Flattening associates each call site with a unique computation (for sequential calling contexts) This simplifies buffer assignment and points-to analysis (see b/36865746 for details).

HLO constant folding

A pass which performs constant folding in order to avoid unnecessary computation on constants.

HLO cse

A pass which performs common-subexpression elimination. Identical constants and identical instructions with the same operands are commoned. The pass iterates over the instructions in topological order which enables the pass to find arbitrarily large common expressions.

HLO dce

HLO pass which removes dead instructions from each computation in the module and removes dead computations from the module.

An instruction is dead if it is not reachable from the root. A computation is dead if it is not the entry computation of the module and it is not reachable from the entry computation.

This pass does not remove dead parameter instructions, as parameter instructions cannot be deleted.

HLO domain isolator

Domain isolation is the task of placing kDomain instructions between HLO instructions having different sharding. A kDomain instruction is essentially used to break an HLO graph edge connecting two instructions with different sharding. If a set of connected instructions have all the same sharding, no kDomain instruction will be placed.

HLO domain remover

Removes all the kDomain instructions of a given kind from the input module, and calls the normalizer to propagate the properties on the possibly new born instructions.

HLO domain verifier

Verifies that the domain instructions are consistent, and the each domain is surrounded by the same metadata.

Verify that the whole kDomain frontier bounding the instruction reach set, has matching metadata. A kDomain instruction has two sides of metadata, a user facing and an operand facing. A reachable instruction set can make contact with a kDomain instruction on a user facing side (the kDomain is operand of the instruction), or on a operand facing side (the kDomain is user of the instruction). And depending on the contact side, the proper metadata object (user_side_metadata() vs. operand_side_metadata()) needs to be used for consistency checks. Returns the DomainMetadata pointer which surrounds the domain, and represents the common metadata within such domain. If the returned DomainMetadata pointer is nullptr, the input domain had no kDomain boundary.

HLO element type converter

A pass that eliminates certain element types as the input or output of ops by inserting Convert ops. This allows a backend to support an element type while only actually implementing the Convert op for that element type. This is generally not the fastest approach, but it works.

HLO get dimension size rewriter

Pass to replace a kGetDimensionSize instruction with a hlo instruction representing the dynamic size if the dimension is dynamic, otherwise a constant instruction representing the static size.

HLO memory scheduler

A pass which schedules the HLO instructions in a module. The HloModule’s schedule field is set to the resulting HloSchedule using HloModule::set_schedule.

HLO trivial scheduler

A pass which produces a naive, but correct schedule. The schedule is produced using a DFS traversal of the graph with no attempt to minimize memory use.

HLO descheduler

A trivial pass which clears the schedule currently set on the HloModule. After this pass runs HloModule::has_schedule will return false.

HLO module dce

HLO pass which removes dead code from computations in the module using HloModule-scoped analysis (HloLivenessAnalysis).

Sweeps through live instructions which cross computation boundaries (kWhile), and removes code at dead shape indices.

HLO rematerialization

HLO pass which rematerializes instructions to reduce peak memory use, where memory use is defined as the total size of all live HLO instruction values. Parameters and constants are included in memory use estimates.

CSE will undo the effects of this optimization and should not be run after this pass. In general, this pass should be run very late, immediately before code generation.

HLO subcomputation unification

Unify subcomputations of a HloModule: if any computations are equal, choose one arbitrarily to use and delete the others.

HLO verifier

HLO pass that verifies invariants of HLO instructions for each computation in the module.

Indexed array analysis

A pass that prints all non-trivial results returned by IndexedArrayAnalysis. This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to unconditionally add to the regular HLO pass pipeline.

Instruction fusion

HLO pass which performs instruction fusion. Instructions are fused “vertically”, meaning producing instructions are fused into their consumers with the intent that the loops which compute their values will be fused in code generation. Derived classes define ShouldFuse method to select which instructions to fuse.

Layout assignment

HLO pass which assigns layouts to all instructions in the HLO module while satisfying all necessary invariants and minimizing cost.

Map inliner

A pass which performs map inlining. This replaces kMap instructions with their equivalent sequence of array operations. For example: map({X, Y}, add) -> add(X, Y)).

Memory space propagation

This is a legalization pass that propagates the memory space in the layout to the fusion computations.

Multi output fusion

This class implements the fusing of sibling fusion instructions that sharing common operands. It constructs the following associated data structures. (1) candidates_: stores the instruction and the set of instructions it can fuse to. (2) candidates_index_: maps instruction to id. (3) reachability_: reachability map in this computation. (4) all_fusion_candidates_: the vector of candidate instructions. (5) worklist_: a priority queue that contains pairs of instructions to be fused and their fusion profit scores.

Function Perform() applies the optimization. It picks up the most profitable pair in the worklist_, checks if it’s legal to fuse and fuses the pair. After fusion, it updates the associated structures such as reachability_, candidates_ and worklist_. Note that the reachability map is updated based on the original computation. This works because the reachability is monotonically increasing with instruction fusion.This class implements the fusing of sibling fusion instructions that sharing common operands. It constructs the following associated data structures. (1) candidates_: stores the instruction and the set of instructions it can fuse to. (2) candidates_index_: maps instruction to id. (3) reachability_: reachability map in this computation. (4) all_fusion_candidates_: the vector of candidate instructions. (5) worklist_: a priority queue that contains pairs of instructions to be fused and their fusion profit scores.

Function Perform() applies the optimization. It picks up the most profitable pair in the worklist_, checks if it’s legal to fuse and fuses the pair. After fusion, it updates the associated structures such as reachability_, candidates_ and worklist_. Note that the reachability map is updated based on the original computation. This works because the reachability is monotonically increasing with instruction fusion.

Op expand pass

This pass is an abstract superclass for passes that replace operations that match a pattern. It is intended to be subclassed, not used directly.

This pass is useful for legalizing HLO instructions that a particular backend does not support into other HLO instructions.

Cholesky expander

Convolution 4D expander

Gather expander

Logistic Expander

Rng bit generator expander

Rng expander

Stable sort expander

Triangular solve expander

Optimize input output buffer alias

This pass opportunistically finds input and output buffers that can be aliased, and writes the alias config into the HloModule.

The input and the output buffers can be in any shape, and each output buffer can alias with an input buffer with the same shape. Each input buffer may only alias with a single output buffer. For example, for the following parameter and the output buffers,

Parameters : { P1(f32[3]), P2(s32[3]), P3(f32[3,12]), P4(f32[16,12]), … } Outputs : { O1(s32[3]), O2(f32[3]), O3(f32[16,12]), … }

one potential aliasing would be (O1, P2), (O2, P1), (O3, P4), ..

Reshape mover

A pass which moves Reshapes and Transposes to let later passes combine them. This now only moves them outputward across elementwise ops all whose operands are equivalent Reshapes or Transposes, but in future could potentially move them inputward also.

Root instruction sinker

Given a scheduled HLO module, this pass sinks the ROOT of the instruction to the bottom of the non-fusion computations. To avoid dependency violations of moving the ROOT instruction, it creates a new ROOT instruction that looks like the following:

  • For tuple ROOT type: new_root = tuple(gte(old_root), gte(old_root), …)
  • For non-tuple ROOT type: new_root = bitcast(old_root)

Scatter expander

Sharding propagation

Propagates sharding information around the graph. HLOs that have shardings are kept as-is, those that do not have shardings are given shardings based on a simple local greedy heuristic.

Slice sinker

An HLO pass that sinks slice operations used by a group of elementwise operations and merges the group of elementwise operations.

Sort simplifier

HLO pass which removes unused operands from sort, where an unused operand is defined as an operand at some index ‘x’ at which the output is not used.

TopK rewriter

This pass pattern-matches soups of HLOs executing a TopK operation and replaces them with a TopK CustomCall when the given values are supported by the CustomCall and it is more efficient to use that implementation.

Transpose folding

HLO pass that folds transpose operators into Dot operators, where the Dot operator is implemented by a GEMM kernel that can transpose its inputs.

Tree reduction rewriter

Increase precision for the reduction operation by applying the reduce-window first.

E.g. suppose we want to reduce f32[1024] to a scalar. This pass first applies a reduce-window (with kSame padding) of size reduce_window_size, and then reduces the resulting array f32[32]. The rewrite is not applied if any of the reduced dimensions is smaller than the reduce_window_size.

Applying this pass until a fixed point performs a variant of pairwise summation (https:en.wikipedia.org/wiki/Pairwise_summation), which is guaranteed to have an asymptotically smaller error bound provided that intermediate roundoff errors are random and have random sign.

If this pass lowers the performance too much, the window size can always be increased to a larger value.

Tuple simplifier

A pass which simplifies patterns of Tuple and GetTupleElement instructions in the module.

While loop constant sinking

Sinks while loop invariant values that happen to be constants into the while loop body and conditional. This is probably not a win in isolation but may unlock further optimizations like constant folding.

state = (…, const, …) while (pred(state)) { (…, v, …) = state use(v) state = (…, v, …) }


state = (…, const, …) while (pred(state)) { (…, v, …) = state use(const) state = (…, v, …) }

Note that it leaves the v in place to keep that component of the state tuple trivially loop invariant. WhileLoopSimplifier will later get rid of v.

TODO(b/79121449): We should also sink broadcasts of constants.

While loop invariant code motion

HLO pass that rewrites while loops to hoist loop invariant instructions in the while body into the computation that contains the while instruction.

While loop simplifier

HLO pass that makes the following transformations on while loops:

  • A while loop with static trip count of 0 is deleted.

  • A while loop with static trip count of 1 is replaced by its body (sans loop).

  • Elements of a while loop’s tuple that the loop doesn’t use are removed from the tuple.

  • If the while loop’s parameter is a nested tuple, it’s flattened to a single-level tuple. This is good because it usually reduces the number of kTuple instructions, but also because it unlocks additional optimizations (e.g. removing unused loop parameters).

Flattening nested while loop tuples adds a whole mess of likely unnecessary kGetTupleElement and kTuple operations to the graph. We expect that tuple simplifier will be run afterwards.

While loop trip count annotator

Pass that annotates while loops with known trip counts.

The annotation is stored as a backend-config on the while loop node.

This pass should run after all passes that might semantically modify a while loop, e.g. by unrolling it. Otherwise, a loop could end up with a backend-config that doesn’t match its true trip-count.

This pass does some pattern-matching on loop bodies and conditions, so it should run after most HLO simplifications and before fusion and layout assignment, which make pattern matching much more difficult by e.g. introducing copy nodes.

Zero sized HLO elimination

HLO pass that replaces zero sized Hlos with a zero sized constant literal.